[language_classifier] Language classifier data set I/O
This commit is contained in:
199
src/language_classifier_io.c
Normal file
199
src/language_classifier_io.c
Normal file
@@ -0,0 +1,199 @@
|
||||
#include "language_classifier_io.h"
|
||||
|
||||
#include "log/log.h"
|
||||
|
||||
#include "constants.h"
|
||||
#include "collections.h"
|
||||
#include "language_features.h"
|
||||
|
||||
language_classifier_data_set_t *language_classifier_data_set_init(char *filename) {
|
||||
language_classifier_data_set_t *data_set = malloc(sizeof(language_classifier_data_set_t));
|
||||
data_set->f = fopen(filename, "r");
|
||||
if (data_set->f == NULL) {
|
||||
free(data_set);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
data_set->tokens = token_array_new();
|
||||
data_set->feature_array = char_array_new();
|
||||
data_set->address = char_array_new();
|
||||
data_set->language = char_array_new_size(MAX_LANGUAGE_LEN);
|
||||
data_set->country = char_array_new_size(MAX_COUNTRY_CODE_LEN);
|
||||
|
||||
return data_set;
|
||||
}
|
||||
|
||||
|
||||
bool language_classifier_data_set_next(language_classifier_data_set_t *self) {
|
||||
if (self == NULL) return false;
|
||||
|
||||
char *line = file_getline(self->f);
|
||||
if (line == NULL) {
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t token_count;
|
||||
|
||||
cstring_array *fields = cstring_array_split(line, TAB_SEPARATOR, TAB_SEPARATOR_LEN, &token_count);
|
||||
|
||||
free(line);
|
||||
|
||||
if (token_count != LANGUAGE_CLASSIFIER_FILE_NUM_TOKENS) {
|
||||
log_error("Token count did not match, ected %d, got %zu\n", LANGUAGE_CLASSIFIER_FILE_NUM_TOKENS, token_count);
|
||||
}
|
||||
|
||||
char *language = cstring_array_get_string(fields, LANGUAGE_CLASSIFIER_FIELD_LANGUAGE);
|
||||
char *country = cstring_array_get_string(fields, LANGUAGE_CLASSIFIER_FIELD_COUNTRY);
|
||||
char *address = cstring_array_get_string(fields, LANGUAGE_CLASSIFIER_FIELD_ADDRESS);
|
||||
|
||||
log_debug("Doing: %s\n", address);
|
||||
|
||||
char *normalized = language_classifier_normalize_string(address);
|
||||
bool is_normalized = normalized != NULL;
|
||||
if (!is_normalized) {
|
||||
log_debug("could not normalize\n");
|
||||
normalized = strdup(address);
|
||||
}
|
||||
|
||||
char_array_clear(self->address);
|
||||
char_array_add(self->address, normalized);
|
||||
|
||||
char_array_clear(self->country);
|
||||
char_array_add(self->country, country);
|
||||
|
||||
char_array_clear(self->language);
|
||||
char_array_add(self->language, language);
|
||||
|
||||
cstring_array_destroy(fields);
|
||||
bool ret = normalized != NULL;
|
||||
free(normalized);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
void language_classifier_minibatch_destroy(language_classifier_minibatch_t *self) {
|
||||
if (self == NULL) return;
|
||||
|
||||
size_t i;
|
||||
|
||||
if (self->features != NULL) {
|
||||
for (i = 0; i < self->features->n; i++) {
|
||||
khash_t(str_double) *feature_counts = self->features->a[i];
|
||||
const char *feature;
|
||||
|
||||
kh_foreach_key(feature_counts, feature, {
|
||||
free((char *)feature);
|
||||
})
|
||||
|
||||
kh_destroy(str_double, feature_counts);
|
||||
}
|
||||
feature_count_array_destroy(self->features);
|
||||
|
||||
}
|
||||
|
||||
if (self->labels != NULL) {
|
||||
cstring_array_destroy(self->labels);
|
||||
}
|
||||
|
||||
free(self);
|
||||
}
|
||||
|
||||
language_classifier_minibatch_t *language_classifier_minibatch_new(void) {
|
||||
language_classifier_minibatch_t *minibatch = malloc(sizeof(language_classifier_minibatch_t));
|
||||
if (minibatch == NULL) return NULL;
|
||||
|
||||
minibatch->features = feature_count_array_new();
|
||||
if (minibatch->features == NULL) {
|
||||
language_classifier_minibatch_destroy(minibatch);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
minibatch->labels = cstring_array_new();
|
||||
if (minibatch->labels == NULL) {
|
||||
language_classifier_minibatch_destroy(minibatch);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return minibatch;
|
||||
}
|
||||
|
||||
inline bool language_classifier_language_is_valid(char *language) {
|
||||
return !string_equals(language, AMBIGUOUS_LANGUAGE) && !string_equals(language, UNKNOWN_LANGUAGE);
|
||||
}
|
||||
|
||||
language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with_size(language_classifier_data_set_t *self, size_t batch_size, bool with_country) {
|
||||
size_t in_batch = 0;
|
||||
|
||||
language_classifier_minibatch_t *minibatch = NULL;
|
||||
|
||||
while (in_batch < batch_size && language_classifier_data_set_next(self)) {
|
||||
char *address = char_array_get_string(self->address);
|
||||
if (strlen(address) == 0) {
|
||||
continue;
|
||||
}
|
||||
char *country = NULL;
|
||||
|
||||
if (with_country) {
|
||||
country = char_array_get_string(self->country);
|
||||
}
|
||||
|
||||
char *language = char_array_get_string(self->language);
|
||||
if (!language_classifier_language_is_valid(language)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (minibatch == NULL) {
|
||||
minibatch = language_classifier_minibatch_new();
|
||||
if (minibatch == NULL) {
|
||||
log_error("Error creating minibatch\n");
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
khash_t(str_double) *feature_counts = extract_language_features(address, country, self->tokens, self->feature_array);
|
||||
if (feature_counts == NULL) {
|
||||
log_error("Could not extract features for: %s\n", address);
|
||||
language_classifier_minibatch_destroy(minibatch);
|
||||
return NULL;
|
||||
}
|
||||
feature_count_array_push(minibatch->features, feature_counts);
|
||||
cstring_array_add_string(minibatch->labels, language);
|
||||
in_batch++;
|
||||
}
|
||||
|
||||
return minibatch;
|
||||
}
|
||||
|
||||
inline language_classifier_minibatch_t *language_classifier_data_set_get_minibatch(language_classifier_data_set_t *self, bool with_country) {
|
||||
return language_classifier_data_set_get_minibatch_with_size(self, LANGUAGE_CLASSIFIER_DEFAULT_BATCH_SIZE, with_country);
|
||||
}
|
||||
|
||||
void language_classifier_data_set_destroy(language_classifier_data_set_t *self) {
|
||||
if (self == NULL) return;
|
||||
|
||||
if (self->f != NULL) {
|
||||
fclose(self->f);
|
||||
}
|
||||
|
||||
if (self->tokens != NULL) {
|
||||
token_array_destroy(self->tokens);
|
||||
}
|
||||
|
||||
if (self->feature_array != NULL) {
|
||||
char_array_destroy(self->feature_array);
|
||||
}
|
||||
|
||||
if (self->address != NULL) {
|
||||
char_array_destroy(self->address);
|
||||
}
|
||||
|
||||
if (self->language != NULL) {
|
||||
char_array_destroy(self->language);
|
||||
}
|
||||
|
||||
if (self->country != NULL) {
|
||||
char_array_destroy(self->country);
|
||||
}
|
||||
|
||||
free(self);
|
||||
}
|
||||
Reference in New Issue
Block a user