diff --git a/src/regularization.c b/src/regularization.c new file mode 100644 index 00000000..c6c20fee --- /dev/null +++ b/src/regularization.c @@ -0,0 +1,31 @@ +#include "regularization.h" +#include "float_utils.h" +#include "log/log.h" + +inline void regularize_l2(double *theta, size_t n, double reg_update) { + for (size_t i = 0; i < n; i++) { + double current_value = theta[i]; + double updated_value = current_value - current_value * reg_update; + // Make sure the regularization update doesn't change the sign of the weight + // Otherwise, set the weight to 0 + if ((updated_value > 0) == (current_value > 0)) { + theta[i] = updated_value; + } else { + theta[i] = 0.0; + } + } +} + +inline void regularize_l1(double *theta, size_t n, double reg_update) { + for (size_t i = 0; i < n; i++) { + double current_value = theta[i]; + double updated_value = current_value - sign(current_value) * reg_update; + // Make sure the regularization update doesn't change the sign of the weight + // Otherwise, set the weight to 0 + if ((updated_value > 0) == (current_value > 0)) { + theta[i] = updated_value; + } else { + theta[i] = 0.0; + } + } +} \ No newline at end of file diff --git a/src/regularization.h b/src/regularization.h new file mode 100644 index 00000000..30d10662 --- /dev/null +++ b/src/regularization.h @@ -0,0 +1,15 @@ +#ifndef REGULARIZATION_H +#define REGULARIZATION_H + +#include + +typedef enum { + REGULARIZATION_NONE, + REGULARIZATION_L1, + REGULARIZATION_L2 +} regularization_type_t; + +void regularize_l2(double *theta, size_t n, double reg_update); +void regularize_l1(double *theta, size_t n, double reg_update); + +#endif \ No newline at end of file