Skip to content

Commit 03e9baa

Browse files
prasunanand9prady9
authored andcommitted
init multiple dtypes
1 parent aa8df4d commit 03e9baa

File tree

3 files changed

+36
-5
lines changed

3 files changed

+36
-5
lines changed

ext/mri/arrayfire.c

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ VALUE Util = Qnil;
1919

2020
// prototypes
2121
void Init_arrayfire();
22+
23+
af_dtype arf_dtype_from_rbsymbol(VALUE sym);
2224
static VALUE arf_init(int argc, VALUE* argv, VALUE self);
2325
static VALUE arf_alloc(VALUE klass);
2426
static void arf_free(afstruct* af);
@@ -598,6 +600,7 @@ VALUE arf_init(int argc, VALUE* argv, VALUE self)
598600
afstruct* afarray;
599601
Data_Get_Struct(self, afstruct, afarray);
600602
if(argc > 0){
603+
af_dtype dtype = (argc == 4) ? arf_dtype_from_rbsymbol(argv[3]) : f32;
601604

602605
dim_t ndims = (dim_t)FIX2LONG(argv[0]);
603606
dim_t* dimensions = (dim_t*)malloc(ndims * sizeof(dim_t));
@@ -611,15 +614,13 @@ VALUE arf_init(int argc, VALUE* argv, VALUE self)
611614
host_array[index] = (float)NUM2DBL(RARRAY_AREF(argv[2], index));
612615
}
613616

614-
af_create_array(&afarray->carray, host_array, ndims, dimensions, f32);
615-
617+
af_create_array(&afarray->carray, host_array, ndims, dimensions, dtype);
616618
af_print_array(afarray->carray);
617619
}
618620

619621
return self;
620622
}
621623

622-
623624
static VALUE arf_alloc(VALUE klass)
624625
{
625626
/* allocate */
@@ -741,11 +742,12 @@ DEF_UNARY_RUBY_ACCESSOR(ceil, ceil)
741742
#include "cmodules/backend.c"
742743
#include "cmodules/blas.c"
743744
#include "cmodules/cuda.c"
745+
#include "cmodules/data.c"
746+
#include "cmodules/defines.c"
744747
#include "cmodules/device.c"
745748
#include "cmodules/index.c"
746-
#include "cmodules/opencl.c"
747-
#include "cmodules/data.c"
748749
#include "cmodules/lapack.c"
750+
#include "cmodules/opencl.c"
749751
#include "cmodules/random.c"
750752
#include "cmodules/statistics.c"
751753
#include "cmodules/util.c"

ext/mri/cmodules/defines.c

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
const char* const DTYPE_NAMES[ARF_NUM_DTYPES] = {
2+
"f32",
3+
"c32",
4+
"f64",
5+
"c64",
6+
"b8",
7+
"s32",
8+
"u32",
9+
"u8",
10+
"s64",
11+
"u64",
12+
"s16",
13+
"u16"
14+
};
15+
16+
af_dtype arf_dtype_from_rbsymbol(VALUE sym) {
17+
ID sym_id = SYM2ID(sym);
18+
19+
for (size_t index = 0; index < ARF_NUM_DTYPES; ++index) {
20+
if (sym_id == rb_intern(DTYPE_NAMES[index])) {
21+
return static_cast<af_dtype>(index);
22+
}
23+
}
24+
25+
VALUE str = rb_any_to_s(sym);
26+
rb_raise(rb_eArgError, "invalid data type symbol (:%s) specified", RSTRING_PTR(str));
27+
}

ext/mri/ruby_arrayfire.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ typedef struct RANDOM_ENGINE_STRUCT
1414
af_random_engine cengine;
1515
}afrandomenginestruct;
1616

17+
#define ARF_NUM_DTYPES 12
18+
1719
#ifndef HAVE_RB_ARRAY_CONST_PTR
1820
static inline const VALUE *
1921
rb_array_const_ptr(VALUE a)

0 commit comments

Comments
 (0)