Skip to content

Commit 90d6c6a

Browse files
prasunanand9prady9
authored andcommitted
add more definitions
1 parent 2bc76b2 commit 90d6c6a

File tree

4 files changed

+70
-0
lines changed

4 files changed

+70
-0
lines changed

ext/mri/arrayfire.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,10 @@ VALUE Util = Qnil;
2121
void Init_arrayfire();
2222

2323
af_dtype arf_dtype_from_rbsymbol(VALUE sym);
24+
af_source arf_source_from_rbsymbol(VALUE sym);
2425
af_mat_prop arf_mat_type_from_rbsymbol(VALUE sym);
26+
af_norm_type arf_norm_type_from_rbsymbol(VALUE sym);
27+
af_moment_type arf_moment_type_from_rbsymbol(VALUE sym);
2528

2629
static VALUE arf_init(int argc, VALUE* argv, VALUE self);
2730
static VALUE arf_alloc(VALUE klass);

ext/mri/cmodules/defines.c

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ const char* const DTYPE_NAMES[ARF_NUM_DTYPES] = {
1313
"u16"
1414
};
1515

16+
const char* const SOURCE_NAMES[ARF_NUM_SOURCES] = {
17+
"afDevice", ///< Device pointer
18+
"afHost" ///< Host pointer
19+
};
20+
1621
std::map<char*, size_t> MAT_PROPERTIES = {
1722
{"AF_MAT_NONE", 0},
1823
{"AF_MAT_TRANS", 1},
@@ -28,6 +33,25 @@ std::map<char*, size_t> MAT_PROPERTIES = {
2833
{"AF_MAT_BLOCK_DIAG", 8192}
2934
};
3035

36+
const char* const NORM_TYPES[ARF_NUM_NORM_TYPES] = {
37+
"AF_NORM_VECTOR_1", ///< treats the input as a vector and returns the sum of absolute values
38+
"AF_NORM_VECTOR_INF", ///< treats the input as a vector and returns the max of absolute values
39+
"AF_NORM_VECTOR_2", ///< treats the input as a vector and returns euclidean norm
40+
"AF_NORM_VECTOR_P", ///< treats the input as a vector and returns the p-norm
41+
"AF_NORM_MATRIX_1", ///< return the max of column sums
42+
"AF_NORM_MATRIX_INF", ///< return the max of row sums
43+
"AF_NORM_MATRIX_2", ///< returns the max singular value). Currently NOT SUPPORTED
44+
"AF_NORM_MATRIX_L_PQ", ///< returns Lpq-norm
45+
"AF_NORM_EUCLID" ///< The default. Same as AF_NORM_VECTOR_2
46+
};
47+
48+
std::map<char*, size_t> MOMENT_TYPES = {
49+
{"AF_MOMENT_M00", 1},
50+
{"AF_MOMENT_M01", 2},
51+
{"AF_MOMENT_M10", 4},
52+
{"AF_MOMENT_M11", 8}
53+
//AF_MOMENT_FIRST_ORDER = AF_MOMENT_M00 | AF_MOMENT_M01 | AF_MOMENT_M10 | AF_MOMENT_M11
54+
};
3155

3256
af_dtype arf_dtype_from_rbsymbol(VALUE sym) {
3357
ID sym_id = SYM2ID(sym);
@@ -42,6 +66,19 @@ af_dtype arf_dtype_from_rbsymbol(VALUE sym) {
4266
rb_raise(rb_eArgError, "invalid data type symbol (:%s) specified", RSTRING_PTR(str));
4367
}
4468

69+
af_source arf_source_from_rbsymbol(VALUE sym) {
70+
ID sym_id = SYM2ID(sym);
71+
72+
for (size_t index = 0; index < ARF_NUM_SOURCES; ++index) {
73+
if (sym_id == rb_intern(SOURCE_NAMES[index])) {
74+
return static_cast<af_source>(index);
75+
}
76+
}
77+
78+
VALUE str = rb_any_to_s(sym);
79+
rb_raise(rb_eArgError, "invalid data type symbol (:%s) specified", RSTRING_PTR(str));
80+
}
81+
4582
af_mat_prop arf_mat_type_from_rbsymbol(VALUE sym) {
4683
ID sym_id = SYM2ID(sym);
4784

@@ -54,3 +91,29 @@ af_mat_prop arf_mat_type_from_rbsymbol(VALUE sym) {
5491
VALUE str = rb_any_to_s(sym);
5592
rb_raise(rb_eArgError, "invalid matrix type symbol (:%s) specified", RSTRING_PTR(str));
5693
}
94+
95+
af_norm_type arf_norm_type_from_rbsymbol(VALUE sym) {
96+
ID sym_id = SYM2ID(sym);
97+
98+
for (size_t index = 0; index < ARF_NUM_NORM_TYPES; ++index) {
99+
if (sym_id == rb_intern(NORM_TYPES[index])) {
100+
return static_cast<af_norm_type>(index);
101+
}
102+
}
103+
104+
VALUE str = rb_any_to_s(sym);
105+
rb_raise(rb_eArgError, "invalid norm type symbol (:%s) specified", RSTRING_PTR(str));
106+
}
107+
108+
af_moment_type arf_moment_type_from_rbsymbol(VALUE sym) {
109+
ID sym_id = SYM2ID(sym);
110+
111+
for(std::map<char*, size_t>::value_type& entry : MOMENT_TYPES) {
112+
if (sym_id == rb_intern(entry.first)) {
113+
return static_cast<af_moment_type>(entry.second);
114+
}
115+
}
116+
117+
VALUE str = rb_any_to_s(sym);
118+
rb_raise(rb_eArgError, "invalid moment type symbol (:%s) specified", RSTRING_PTR(str));
119+
}

ext/mri/ruby_arrayfire.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
#include <arrayfire.h>
55
#include <stdio.h>
66
#include <math.h>
7+
#include <map>
8+
79
/*
810
* Project Includes
911
*/

ext/mri/ruby_arrayfire.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ typedef struct RANDOM_ENGINE_STRUCT
1515
}afrandomenginestruct;
1616

1717
#define ARF_NUM_DTYPES 12
18+
#define ARF_NUM_SOURCES 2
19+
#define ARF_NUM_NORM_TYPES 9
1820

1921
#ifndef HAVE_RB_ARRAY_CONST_PTR
2022
static inline const VALUE *

0 commit comments

Comments
 (0)