[utils] Adding min, max, argmin, argmax and log_sum_exp to generic vector math header
This commit is contained in:
@@ -36,6 +36,57 @@
|
|||||||
return cpy; \
|
return cpy; \
|
||||||
} \
|
} \
|
||||||
\
|
\
|
||||||
|
static inline type name##_max(name *vector, size_t n) { \
|
||||||
|
type max_val = 0; \
|
||||||
|
type val; \
|
||||||
|
for (int i = 0; i < n; i++) { \
|
||||||
|
val = vector->a[i]; \
|
||||||
|
if (val > max_val) max_val = val; \
|
||||||
|
} \
|
||||||
|
return max_val; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
static inline type name##_min(name *vector, size_t n) { \
|
||||||
|
if (n < 1) return (type) 0; \
|
||||||
|
type val = vector->a[0]; \
|
||||||
|
type min_val = val; \
|
||||||
|
for (int i = 1; i < n; i++) { \
|
||||||
|
val = vector->a[i]; \
|
||||||
|
if (val < min_val) min_val = val; \
|
||||||
|
} \
|
||||||
|
return min_val; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
static inline int64_t name##_argmax(name *vector, size_t n) { \
|
||||||
|
if (n < 1) return -1; \
|
||||||
|
type max_val = 0; \
|
||||||
|
int64_t argmax = 0; \
|
||||||
|
type val; \
|
||||||
|
for (int i = 0; i < n; i++) { \
|
||||||
|
val = vector->a[i]; \
|
||||||
|
if (val > max_val) { \
|
||||||
|
max_val = val; \
|
||||||
|
argmax = i; \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
return argmax; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
static inline int64_t name##_argmin(name *vector, size_t n) { \
|
||||||
|
if (n < 1) return (type) -1; \
|
||||||
|
type val = vector->a[0]; \
|
||||||
|
type min_val = val; \
|
||||||
|
int64_t argmin = 0; \
|
||||||
|
for (int i = 1; i < n; i++) { \
|
||||||
|
val = vector->a[i]; \
|
||||||
|
if (val < min_val) { \
|
||||||
|
min_val = val; \
|
||||||
|
argmin = i; \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
return argmin; \
|
||||||
|
} \
|
||||||
|
\
|
||||||
static inline void name##_add(name *vector, type c, size_t n) { \
|
static inline void name##_add(name *vector, type c, size_t n) { \
|
||||||
for (int i = 0; i < n; i++) { \
|
for (int i = 0; i < n; i++) { \
|
||||||
vector->a[i] += c; \
|
vector->a[i] += c; \
|
||||||
@@ -96,6 +147,15 @@
|
|||||||
return result; \
|
return result; \
|
||||||
} \
|
} \
|
||||||
\
|
\
|
||||||
|
static inline type name##_log_sum_exp(name *vector, size_t n) { \
|
||||||
|
type max = name##_max(vector, n); \
|
||||||
|
type result = 0; \
|
||||||
|
for (int i = 0; i < n; i++) { \
|
||||||
|
result += exp(vector->a[i] - max); \
|
||||||
|
} \
|
||||||
|
return max + log(result); \
|
||||||
|
} \
|
||||||
|
\
|
||||||
static inline void name##_add_vector(name *v1, name *v2, size_t n) { \
|
static inline void name##_add_vector(name *v1, name *v2, size_t n) { \
|
||||||
for (int i = 0; i < n; i++) { \
|
for (int i = 0; i < n; i++) { \
|
||||||
v1->a[i] += v2->a[i]; \
|
v1->a[i] += v2->a[i]; \
|
||||||
|
|||||||
Reference in New Issue
Block a user