diff --git a/src/language_classifier_io.c b/src/language_classifier_io.c new file mode 100644 index 00000000..6be0d84d --- /dev/null +++ b/src/language_classifier_io.c @@ -0,0 +1,199 @@ +#include "language_classifier_io.h" + +#include "log/log.h" + +#include "constants.h" +#include "collections.h" +#include "language_features.h" + +language_classifier_data_set_t *language_classifier_data_set_init(char *filename) { + language_classifier_data_set_t *data_set = malloc(sizeof(language_classifier_data_set_t)); + data_set->f = fopen(filename, "r"); + if (data_set->f == NULL) { + free(data_set); + return NULL; + } + + data_set->tokens = token_array_new(); + data_set->feature_array = char_array_new(); + data_set->address = char_array_new(); + data_set->language = char_array_new_size(MAX_LANGUAGE_LEN); + data_set->country = char_array_new_size(MAX_COUNTRY_CODE_LEN); + + return data_set; +} + + +bool language_classifier_data_set_next(language_classifier_data_set_t *self) { + if (self == NULL) return false; + + char *line = file_getline(self->f); + if (line == NULL) { + return false; + } + + size_t token_count; + + cstring_array *fields = cstring_array_split(line, TAB_SEPARATOR, TAB_SEPARATOR_LEN, &token_count); + + free(line); + + if (token_count != LANGUAGE_CLASSIFIER_FILE_NUM_TOKENS) { + log_error("Token count did not match, ected %d, got %zu\n", LANGUAGE_CLASSIFIER_FILE_NUM_TOKENS, token_count); + } + + char *language = cstring_array_get_string(fields, LANGUAGE_CLASSIFIER_FIELD_LANGUAGE); + char *country = cstring_array_get_string(fields, LANGUAGE_CLASSIFIER_FIELD_COUNTRY); + char *address = cstring_array_get_string(fields, LANGUAGE_CLASSIFIER_FIELD_ADDRESS); + + log_debug("Doing: %s\n", address); + + char *normalized = language_classifier_normalize_string(address); + bool is_normalized = normalized != NULL; + if (!is_normalized) { + log_debug("could not normalize\n"); + normalized = strdup(address); + } + + char_array_clear(self->address); + char_array_add(self->address, normalized); + + char_array_clear(self->country); + char_array_add(self->country, country); + + char_array_clear(self->language); + char_array_add(self->language, language); + + cstring_array_destroy(fields); + bool ret = normalized != NULL; + free(normalized); + + return ret; +} + +void language_classifier_minibatch_destroy(language_classifier_minibatch_t *self) { + if (self == NULL) return; + + size_t i; + + if (self->features != NULL) { + for (i = 0; i < self->features->n; i++) { + khash_t(str_double) *feature_counts = self->features->a[i]; + const char *feature; + + kh_foreach_key(feature_counts, feature, { + free((char *)feature); + }) + + kh_destroy(str_double, feature_counts); + } + feature_count_array_destroy(self->features); + + } + + if (self->labels != NULL) { + cstring_array_destroy(self->labels); + } + + free(self); +} + +language_classifier_minibatch_t *language_classifier_minibatch_new(void) { + language_classifier_minibatch_t *minibatch = malloc(sizeof(language_classifier_minibatch_t)); + if (minibatch == NULL) return NULL; + + minibatch->features = feature_count_array_new(); + if (minibatch->features == NULL) { + language_classifier_minibatch_destroy(minibatch); + return NULL; + } + + minibatch->labels = cstring_array_new(); + if (minibatch->labels == NULL) { + language_classifier_minibatch_destroy(minibatch); + return NULL; + } + + return minibatch; +} + +inline bool language_classifier_language_is_valid(char *language) { + return !string_equals(language, AMBIGUOUS_LANGUAGE) && !string_equals(language, UNKNOWN_LANGUAGE); +} + +language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with_size(language_classifier_data_set_t *self, size_t batch_size, bool with_country) { + size_t in_batch = 0; + + language_classifier_minibatch_t *minibatch = NULL; + + while (in_batch < batch_size && language_classifier_data_set_next(self)) { + char *address = char_array_get_string(self->address); + if (strlen(address) == 0) { + continue; + } + char *country = NULL; + + if (with_country) { + country = char_array_get_string(self->country); + } + + char *language = char_array_get_string(self->language); + if (!language_classifier_language_is_valid(language)) { + continue; + } + + if (minibatch == NULL) { + minibatch = language_classifier_minibatch_new(); + if (minibatch == NULL) { + log_error("Error creating minibatch\n"); + return NULL; + } + } + + khash_t(str_double) *feature_counts = extract_language_features(address, country, self->tokens, self->feature_array); + if (feature_counts == NULL) { + log_error("Could not extract features for: %s\n", address); + language_classifier_minibatch_destroy(minibatch); + return NULL; + } + feature_count_array_push(minibatch->features, feature_counts); + cstring_array_add_string(minibatch->labels, language); + in_batch++; + } + + return minibatch; +} + +inline language_classifier_minibatch_t *language_classifier_data_set_get_minibatch(language_classifier_data_set_t *self, bool with_country) { + return language_classifier_data_set_get_minibatch_with_size(self, LANGUAGE_CLASSIFIER_DEFAULT_BATCH_SIZE, with_country); +} + +void language_classifier_data_set_destroy(language_classifier_data_set_t *self) { + if (self == NULL) return; + + if (self->f != NULL) { + fclose(self->f); + } + + if (self->tokens != NULL) { + token_array_destroy(self->tokens); + } + + if (self->feature_array != NULL) { + char_array_destroy(self->feature_array); + } + + if (self->address != NULL) { + char_array_destroy(self->address); + } + + if (self->language != NULL) { + char_array_destroy(self->language); + } + + if (self->country != NULL) { + char_array_destroy(self->country); + } + + free(self); +} \ No newline at end of file diff --git a/src/language_classifier_io.h b/src/language_classifier_io.h new file mode 100644 index 00000000..09a4dead --- /dev/null +++ b/src/language_classifier_io.h @@ -0,0 +1,49 @@ +#ifndef LANGUAGE_CLASSIFIER_IO_H +#define LANGUAGE_CLASSIFIER_IO_H + +#include +#include +#include + +#include "collections.h" +#include "features.h" +#include "file_utils.h" +#include "language_classifier.h" +#include "scanner.h" +#include "string_utils.h" + +#define AMBIGUOUS_LANGUAGE "xxx" +#define UNKNOWN_LANGUAGE "unk" + +#define LANGUAGE_CLASSIFIER_DEFAULT_BATCH_SIZE 1000 + +enum language_classifier_training_data_fields { + LANGUAGE_CLASSIFIER_FIELD_LANGUAGE, + LANGUAGE_CLASSIFIER_FIELD_COUNTRY, + LANGUAGE_CLASSIFIER_FIELD_ADDRESS, + LANGUAGE_CLASSIFIER_FILE_NUM_TOKENS +}; + +typedef struct language_classifier_data_set { + FILE *f; + token_array *tokens; + char_array *feature_array; + char_array *address; + char_array *language; + char_array *country; +} language_classifier_data_set_t; + +typedef struct language_classifier_minibatch { + feature_count_array *features; + cstring_array *labels; +} language_classifier_minibatch_t; + +language_classifier_data_set_t *language_classifier_data_set_init(char *filename); +bool language_classifier_data_set_next(language_classifier_data_set_t *self); +void language_classifier_data_set_destroy(language_classifier_data_set_t *self); + +language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with_size(language_classifier_data_set_t *self, size_t batch_size, bool with_country); +language_classifier_minibatch_t *language_classifier_data_set_get_minibatch(language_classifier_data_set_t *self, bool with_country); +void language_classifier_minibatch_destroy(language_classifier_minibatch_t *self); + +#endif \ No newline at end of file