diff --git a/src/averaged_perceptron_trainer.c b/src/averaged_perceptron_trainer.c index a82c3721..c150af67 100644 --- a/src/averaged_perceptron_trainer.c +++ b/src/averaged_perceptron_trainer.c @@ -316,20 +316,27 @@ uint32_t averaged_perceptron_trainer_predict(averaged_perceptron_trainer_t *self double_array_zero(scores->a, scores->n); + uint64_t *update_counts = self->update_counts->a; + cstring_array_foreach(features, i, feature, { if (!averaged_perceptron_trainer_get_feature_id(self, feature, &feature_id, add_if_missing)) { continue; } - weights = averaged_perceptron_trainer_get_class_weights(self, feature_id, add_if_missing); + uint64_t update_count = update_counts[feature_id]; + bool keep_feature = update_count >= self->min_updates; - if (weights == NULL) { - continue; + if (keep_feature) { + weights = averaged_perceptron_trainer_get_class_weights(self, feature_id, add_if_missing); + + if (weights == NULL) { + continue; + } + + kh_foreach(weights, class_id, weight, { + scores->a[class_id] += weight.value; + }) } - - kh_foreach(weights, class_id, weight, { - scores->a[class_id] += weight.value; - }) }) int64_t max_score = double_array_argmax(scores->a, scores->n);