[language_classifier] Language classifier structs, I/O and API
This commit is contained in:
279
src/language_classifier.c
Normal file
279
src/language_classifier.c
Normal file
@@ -0,0 +1,279 @@
|
||||
#include "language_classifier.h"
|
||||
|
||||
#include <float.h>
|
||||
|
||||
#include "language_features.h"
|
||||
#include "minibatch.h"
|
||||
#include "normalize.h"
|
||||
#include "token_types.h"
|
||||
#include "unicode_scripts.h"
|
||||
|
||||
#define LANGUAGE_CLASSIFIER_SIGNATURE 0xCCCCCCCC
|
||||
|
||||
#define MIN_PROB (0.05 - DBL_EPSILON)
|
||||
|
||||
static language_classifier_t *language_classifier = NULL;
|
||||
static language_classifier_t *language_classifier_country = NULL;
|
||||
|
||||
void language_classifier_destroy(language_classifier_t *self) {
|
||||
if (self == NULL) return;
|
||||
|
||||
if (self->features != NULL) {
|
||||
trie_destroy(self->features);
|
||||
}
|
||||
|
||||
if (self->labels != NULL) {
|
||||
cstring_array_destroy(self->labels);
|
||||
}
|
||||
|
||||
if (self->weights != NULL) {
|
||||
matrix_destroy(self->weights);
|
||||
}
|
||||
|
||||
free(self);
|
||||
}
|
||||
|
||||
language_classifier_t *language_classifier_new(void) {
|
||||
language_classifier_t *language_classifier = malloc(sizeof(language_classifier_t));
|
||||
return language_classifier;
|
||||
}
|
||||
|
||||
language_classifier_t *get_language_classifier(void) {
|
||||
return language_classifier;
|
||||
}
|
||||
|
||||
language_classifier_t *get_language_classifier_country(void) {
|
||||
return language_classifier_country;
|
||||
}
|
||||
|
||||
language_classifier_response_t *classify_languages(char *address, char *country) {
|
||||
language_classifier_t *classifier = NULL;
|
||||
|
||||
if (country == NULL) {
|
||||
classifier = get_language_classifier();
|
||||
} else {
|
||||
classifier = get_language_classifier_country();
|
||||
}
|
||||
|
||||
if (classifier == NULL) {
|
||||
log_error("classifier NULL\n");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
char *normalized = language_classifier_normalize_string(address);
|
||||
|
||||
token_array *tokens = token_array_new();
|
||||
char_array *feature_array = char_array_new();
|
||||
|
||||
khash_t(str_double) *feature_counts = extract_language_features(normalized, country, tokens, feature_array);
|
||||
|
||||
sparse_matrix_t *x = feature_vector(classifier->features, feature_counts);
|
||||
|
||||
size_t n = classifier->num_labels;
|
||||
matrix_t *p_y = matrix_new_zeros(1, n);
|
||||
|
||||
language_classifier_response_t *response = NULL;
|
||||
if (logistic_regression_model_expectation(classifier->weights, x, p_y)) {
|
||||
double *predictions = matrix_get_row(p_y, 0);
|
||||
size_t *indices = double_array_argsort(predictions, n);
|
||||
size_t num_languages = 0;
|
||||
size_t i;
|
||||
double prob;
|
||||
|
||||
double min_prob = 1.0 / n;
|
||||
if (min_prob < MIN_PROB) min_prob = MIN_PROB;
|
||||
|
||||
for (i = 0; i < n; i++) {
|
||||
size_t idx = indices[n - i - 1];
|
||||
prob = predictions[idx];
|
||||
|
||||
if (i == 0 || prob > min_prob) {
|
||||
num_languages++;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
char **languages = malloc(sizeof(char *) * num_languages);
|
||||
double *probs = malloc(sizeof(double) * num_languages);
|
||||
|
||||
for (i = 0; i < num_languages; i++) {
|
||||
size_t idx = indices[n - i - 1];
|
||||
char *lang = cstring_array_get_string(classifier->labels, (uint32_t)idx);
|
||||
prob = predictions[idx];
|
||||
languages[i] = lang;
|
||||
probs[i] = prob;
|
||||
}
|
||||
|
||||
free(indices);
|
||||
|
||||
response = malloc(sizeof(language_classifier_response_t));
|
||||
response->num_languages = num_languages;
|
||||
response->languages = languages;
|
||||
response->probs = probs;
|
||||
}
|
||||
|
||||
exit_tokens_created:
|
||||
token_array_destroy(tokens);
|
||||
char_array_destroy(feature_array);
|
||||
const char *key;
|
||||
kh_foreach_key(feature_counts, key, {
|
||||
free((char *)key);
|
||||
})
|
||||
kh_destroy(str_double, feature_counts);
|
||||
return response;
|
||||
|
||||
}
|
||||
|
||||
language_classifier_t *language_classifier_read(FILE *f) {
|
||||
if (f == NULL) return NULL;
|
||||
long save_pos = ftell(f);
|
||||
|
||||
uint32_t signature;
|
||||
|
||||
if (!file_read_uint32(f, &signature) || signature != LANGUAGE_CLASSIFIER_SIGNATURE) {
|
||||
goto exit_file_read;
|
||||
}
|
||||
|
||||
language_classifier_t *classifier = language_classifier_new();
|
||||
if (classifier == NULL) {
|
||||
goto exit_file_read;
|
||||
}
|
||||
|
||||
trie_t *features = trie_read(f);
|
||||
if (features == NULL) {
|
||||
goto exit_classifier_created;
|
||||
}
|
||||
classifier->features = features;
|
||||
uint64_t num_features;
|
||||
if (!file_read_uint64(f, &num_features)) {
|
||||
goto exit_classifier_created;
|
||||
}
|
||||
classifier->num_features = (size_t)num_features;
|
||||
|
||||
uint64_t labels_str_len;
|
||||
|
||||
if (!file_read_uint64(f, &labels_str_len)) {
|
||||
goto exit_classifier_created;
|
||||
}
|
||||
|
||||
char_array *array = char_array_new_size(labels_str_len);
|
||||
|
||||
if (array == NULL) {
|
||||
goto exit_classifier_created;
|
||||
}
|
||||
|
||||
if (!file_read_chars(f, array->a, labels_str_len)) {
|
||||
char_array_destroy(array);
|
||||
goto exit_classifier_created;
|
||||
}
|
||||
|
||||
array->n = labels_str_len;
|
||||
|
||||
classifier->labels = cstring_array_from_char_array(array);
|
||||
if (classifier->labels == NULL) {
|
||||
goto exit_classifier_created;
|
||||
}
|
||||
classifier->num_labels = cstring_array_num_strings(classifier->labels);
|
||||
|
||||
matrix_t *weights = matrix_read(f);
|
||||
|
||||
if (weights == NULL) {
|
||||
goto exit_classifier_created;
|
||||
}
|
||||
|
||||
classifier->weights = weights;
|
||||
|
||||
return classifier;
|
||||
|
||||
exit_classifier_created:
|
||||
language_classifier_destroy(classifier);
|
||||
exit_file_read:
|
||||
fseek(f, save_pos, SEEK_SET);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
||||
language_classifier_t *language_classifier_load(char *path) {
|
||||
FILE *f;
|
||||
|
||||
f = fopen(path, "rb");
|
||||
if (!f) return NULL;
|
||||
|
||||
language_classifier_t *classifier = language_classifier_read(f);
|
||||
|
||||
fclose(f);
|
||||
return classifier;
|
||||
}
|
||||
|
||||
bool language_classifier_write(language_classifier_t *self, FILE *f) {
|
||||
if (f == NULL || self == NULL) return false;
|
||||
|
||||
if (!file_write_uint32(f, LANGUAGE_CLASSIFIER_SIGNATURE) ||
|
||||
!trie_write(self->features, f) ||
|
||||
!file_write_uint64(f, self->num_features) ||
|
||||
!file_write_uint64(f, self->labels->str->n) ||
|
||||
!file_write_chars(f, (const char *)self->labels->str->a, self->labels->str->n) ||
|
||||
!matrix_write(self->weights, f)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool language_classifier_save(language_classifier_t *self, char *path) {
|
||||
if (self == NULL || path == NULL) return false;
|
||||
|
||||
FILE *f = fopen(path, "wb");
|
||||
if (!f) return false;
|
||||
|
||||
bool result = language_classifier_write(self, f);
|
||||
fclose(f);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Module setup/teardown
|
||||
|
||||
bool language_classifier_module_setup(char *dir) {
|
||||
if (language_classifier != NULL && language_classifier_country != NULL) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (dir == NULL) {
|
||||
dir = LIBPOSTAL_LANGUAGE_CLASSIFIER_DIR;
|
||||
}
|
||||
|
||||
char *classifier_path;
|
||||
|
||||
char_array *path = char_array_new_size(strlen(dir) + PATH_SEPARATOR_LEN + strlen(LANGUAGE_CLASSIFIER_COUNTRY_FILENAME));
|
||||
if (language_classifier == NULL) {
|
||||
char_array_cat_joined(path, PATH_SEPARATOR, true, 2, dir, LANGUAGE_CLASSIFIER_FILENAME);
|
||||
classifier_path = char_array_get_string(path);
|
||||
|
||||
language_classifier = language_classifier_load(classifier_path);
|
||||
|
||||
}
|
||||
|
||||
if (language_classifier_country == NULL) {
|
||||
char_array_clear(path);
|
||||
char_array_cat_joined(path, PATH_SEPARATOR, true, 2, dir, LANGUAGE_CLASSIFIER_COUNTRY_FILENAME);
|
||||
classifier_path = char_array_get_string(path);
|
||||
|
||||
language_classifier_country = language_classifier_load(classifier_path);
|
||||
|
||||
}
|
||||
|
||||
char_array_destroy(path);
|
||||
return true;
|
||||
}
|
||||
|
||||
void language_classifier_module_teardown(void) {
|
||||
if (language_classifier != NULL) {
|
||||
language_classifier_destroy(language_classifier);
|
||||
}
|
||||
|
||||
if (language_classifier_country != NULL) {
|
||||
language_classifier_destroy(language_classifier_country);
|
||||
}
|
||||
}
|
||||
|
||||
55
src/language_classifier.h
Normal file
55
src/language_classifier.h
Normal file
@@ -0,0 +1,55 @@
|
||||
#ifndef LANGUAGE_CLASSIFIER_H
|
||||
#define LANGUAGE_CLASSIFIER_H
|
||||
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <stdint.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
#include "collections.h"
|
||||
#include "language_features.h"
|
||||
#include "logistic_regression.h"
|
||||
#include "tokens.h"
|
||||
#include "string_utils.h"
|
||||
#include "trie.h"
|
||||
|
||||
#define LANGUAGE_CLASSIFIER_FILENAME "language_classifier.dat"
|
||||
#define LANGUAGE_CLASSIFIER_COUNTRY_FILENAME "language_classifier_country.dat"
|
||||
|
||||
typedef struct language_classifier {
|
||||
size_t num_labels;
|
||||
size_t num_features;
|
||||
trie_t *features;
|
||||
cstring_array *labels;
|
||||
matrix_t *weights;
|
||||
} language_classifier_t;
|
||||
|
||||
|
||||
typedef struct language_classifier_response {
|
||||
size_t num_languages;
|
||||
char **languages;
|
||||
double *probs;
|
||||
} language_classifier_response_t;
|
||||
|
||||
// General usage
|
||||
|
||||
language_classifier_t *language_classifier_new(void);
|
||||
language_classifier_t *get_language_classifier(void);
|
||||
language_classifier_t *get_language_classifier_country(void);
|
||||
|
||||
language_classifier_response_t *classify_languages(char *address, char *country);
|
||||
|
||||
void language_classifier_destroy(language_classifier_t *self);
|
||||
|
||||
// I/O methods
|
||||
|
||||
language_classifier_t *language_classifier_load(char *path);
|
||||
bool language_classifier_save(language_classifier_t *self, char *output_dir);
|
||||
|
||||
// Module setup/teardown
|
||||
|
||||
bool language_classifier_module_setup(char *dir);
|
||||
void language_classifier_module_teardown(void);
|
||||
|
||||
|
||||
#endif
|
||||
Reference in New Issue
Block a user