diff --git a/src/vector_math.h b/src/vector_math.h index 875289b1..202ab64f 100644 --- a/src/vector_math.h +++ b/src/vector_math.h @@ -36,6 +36,57 @@ return cpy; \ } \ \ + static inline type name##_max(name *vector, size_t n) { \ + type max_val = 0; \ + type val; \ + for (int i = 0; i < n; i++) { \ + val = vector->a[i]; \ + if (val > max_val) max_val = val; \ + } \ + return max_val; \ + } \ + \ + static inline type name##_min(name *vector, size_t n) { \ + if (n < 1) return (type) 0; \ + type val = vector->a[0]; \ + type min_val = val; \ + for (int i = 1; i < n; i++) { \ + val = vector->a[i]; \ + if (val < min_val) min_val = val; \ + } \ + return min_val; \ + } \ + \ + static inline int64_t name##_argmax(name *vector, size_t n) { \ + if (n < 1) return -1; \ + type max_val = 0; \ + int64_t argmax = 0; \ + type val; \ + for (int i = 0; i < n; i++) { \ + val = vector->a[i]; \ + if (val > max_val) { \ + max_val = val; \ + argmax = i; \ + } \ + } \ + return argmax; \ + } \ + \ + static inline int64_t name##_argmin(name *vector, size_t n) { \ + if (n < 1) return (type) -1; \ + type val = vector->a[0]; \ + type min_val = val; \ + int64_t argmin = 0; \ + for (int i = 1; i < n; i++) { \ + val = vector->a[i]; \ + if (val < min_val) { \ + min_val = val; \ + argmin = i; \ + } \ + } \ + return argmin; \ + } \ + \ static inline void name##_add(name *vector, type c, size_t n) { \ for (int i = 0; i < n; i++) { \ vector->a[i] += c; \ @@ -96,6 +147,15 @@ return result; \ } \ \ + static inline type name##_log_sum_exp(name *vector, size_t n) { \ + type max = name##_max(vector, n); \ + type result = 0; \ + for (int i = 0; i < n; i++) { \ + result += exp(vector->a[i] - max); \ + } \ + return max + log(result); \ + } \ + \ static inline void name##_add_vector(name *v1, name *v2, size_t n) { \ for (int i = 0; i < n; i++) { \ v1->a[i] += v2->a[i]; \