[language_classification] Automatic hyperparameter optimization using either the cross-validation set or two distinct subsets of the training set
This commit is contained in:
@@ -121,7 +121,7 @@ inline bool language_classifier_language_is_valid(char *language) {
|
||||
return !string_equals(language, AMBIGUOUS_LANGUAGE) && !string_equals(language, UNKNOWN_LANGUAGE);
|
||||
}
|
||||
|
||||
language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with_size(language_classifier_data_set_t *self, size_t batch_size, bool with_country) {
|
||||
language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with_size(language_classifier_data_set_t *self, khash_t(str_uint32) *labels, size_t batch_size) {
|
||||
size_t in_batch = 0;
|
||||
|
||||
language_classifier_minibatch_t *minibatch = NULL;
|
||||
@@ -131,17 +131,19 @@ language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with
|
||||
if (strlen(address) == 0) {
|
||||
continue;
|
||||
}
|
||||
char *country = NULL;
|
||||
|
||||
if (with_country) {
|
||||
country = char_array_get_string(self->country);
|
||||
}
|
||||
char *country = NULL;
|
||||
//char *country = char_array_get_string(self->country);
|
||||
|
||||
char *language = char_array_get_string(self->language);
|
||||
if (!language_classifier_language_is_valid(language)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (labels != NULL && kh_get(str_uint32, labels, language) == kh_end(labels)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (minibatch == NULL) {
|
||||
minibatch = language_classifier_minibatch_new();
|
||||
if (minibatch == NULL) {
|
||||
@@ -150,13 +152,16 @@ language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with
|
||||
}
|
||||
}
|
||||
|
||||
khash_t(str_double) *feature_counts = extract_language_features(address, country, self->tokens, self->feature_array);
|
||||
if (feature_counts == NULL) {
|
||||
log_error("Could not extract features for: %s\n", address);
|
||||
language_classifier_minibatch_destroy(minibatch);
|
||||
return NULL;
|
||||
if (labels != NULL) {
|
||||
khash_t(str_double) *feature_counts = extract_language_features(address, country, self->tokens, self->feature_array);
|
||||
if (feature_counts == NULL) {
|
||||
log_error("Could not extract features for: %s\n", address);
|
||||
language_classifier_minibatch_destroy(minibatch);
|
||||
return NULL;
|
||||
}
|
||||
feature_count_array_push(minibatch->features, feature_counts);
|
||||
}
|
||||
feature_count_array_push(minibatch->features, feature_counts);
|
||||
|
||||
cstring_array_add_string(minibatch->labels, language);
|
||||
in_batch++;
|
||||
}
|
||||
@@ -164,8 +169,8 @@ language_classifier_minibatch_t *language_classifier_data_set_get_minibatch_with
|
||||
return minibatch;
|
||||
}
|
||||
|
||||
inline language_classifier_minibatch_t *language_classifier_data_set_get_minibatch(language_classifier_data_set_t *self, bool with_country) {
|
||||
return language_classifier_data_set_get_minibatch_with_size(self, LANGUAGE_CLASSIFIER_DEFAULT_BATCH_SIZE, with_country);
|
||||
inline language_classifier_minibatch_t *language_classifier_data_set_get_minibatch(language_classifier_data_set_t *self, khash_t(str_uint32) *labels) {
|
||||
return language_classifier_data_set_get_minibatch_with_size(self, labels, LANGUAGE_CLASSIFIER_DEFAULT_BATCH_SIZE);
|
||||
}
|
||||
|
||||
void language_classifier_data_set_destroy(language_classifier_data_set_t *self) {
|
||||
|
||||
Reference in New Issue
Block a user