[language_classification] Automatic hyperparameter optimization using either the cross-validation set or two distinct subsets of the training set

This commit is contained in:
Al
2016-01-17 21:11:37 -05:00
parent af5689ee52
commit f808f74271
6 changed files with 299 additions and 112 deletions

View File

@@ -121,7 +121,7 @@ inline bool language_classifier_language_is_valid(char *language) {
return !string_equals(language, AMBIGUOUS_LANGUAGE) && !string_equals(language, UNKNOWN_LANGUAGE);
}
language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with_size(language_classifier_data_set_t *self, size_t batch_size, bool with_country) {
language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with_size(language_classifier_data_set_t *self, khash_t(str_uint32) *labels, size_t batch_size) {
size_t in_batch = 0;
language_classifier_minibatch_t *minibatch = NULL;
@@ -131,17 +131,19 @@ language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with
if (strlen(address) == 0) {
continue;
}
char *country = NULL;
if (with_country) {
country = char_array_get_string(self->country);
}
char *country = NULL;
//char *country = char_array_get_string(self->country);
char *language = char_array_get_string(self->language);
if (!language_classifier_language_is_valid(language)) {
continue;
}
if (labels != NULL && kh_get(str_uint32, labels, language) == kh_end(labels)) {
continue;
}
if (minibatch == NULL) {
minibatch = language_classifier_minibatch_new();
if (minibatch == NULL) {
@@ -150,13 +152,16 @@ language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with
}
}
khash_t(str_double) *feature_counts = extract_language_features(address, country, self->tokens, self->feature_array);
if (feature_counts == NULL) {
log_error("Could not extract features for: %s\n", address);
language_classifier_minibatch_destroy(minibatch);
return NULL;
if (labels != NULL) {
khash_t(str_double) *feature_counts = extract_language_features(address, country, self->tokens, self->feature_array);
if (feature_counts == NULL) {
log_error("Could not extract features for: %s\n", address);
language_classifier_minibatch_destroy(minibatch);
return NULL;
}
feature_count_array_push(minibatch->features, feature_counts);
}
feature_count_array_push(minibatch->features, feature_counts);
cstring_array_add_string(minibatch->labels, language);
in_batch++;
}
@@ -164,8 +169,8 @@ language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with
return minibatch;
}
inline language_classifier_minibatch_t *language_classifier_data_set_get_minibatch(language_classifier_data_set_t *self, bool with_country) {
return language_classifier_data_set_get_minibatch_with_size(self, LANGUAGE_CLASSIFIER_DEFAULT_BATCH_SIZE, with_country);
inline language_classifier_minibatch_t *language_classifier_data_set_get_minibatch(language_classifier_data_set_t *self, khash_t(str_uint32) *labels) {
return language_classifier_data_set_get_minibatch_with_size(self, labels, LANGUAGE_CLASSIFIER_DEFAULT_BATCH_SIZE);
}
void language_classifier_data_set_destroy(language_classifier_data_set_t *self) {