Skip to content

Commit 31a1bbd

Browse files
committed
Added wrapper for simple model initialization
1 parent be66671 commit 31a1bbd

File tree

2 files changed

+127
-0
lines changed

2 files changed

+127
-0
lines changed

Source/embeddedML.c

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
/*
22
* EMBEDDEDML v1.3b
3+
* Lightweight Version
4+
* - Optimized for resource constrained platforms.
35
*/
46

57
/*
@@ -198,3 +200,118 @@ float relu2_derivative(float x){
198200
else if(x > 1.0) return 0.1;
199201
return 1.0;
200202
}
203+
204+
//----Wrapper Functions-----
205+
206+
void set_model_memory(ANN *model, float *weights, float *dedw, float *bias, float *output){
207+
model->weights = weights;
208+
model->dedw = dedw;
209+
model->bias = bias;
210+
model->output = output;
211+
}
212+
213+
void set_model_parameters(ANN *model, unsigned int *topology, unsigned int nlayers, char activation_function){
214+
model->topology = topology;
215+
model->n_layers = nlayers;
216+
217+
int i;
218+
int nweights = 0, nbias = 0;
219+
for(i = 1; i < nlayers; i++){
220+
nweights += topology[i]*topology[i-1];
221+
nbias += topology[i-1];
222+
}
223+
224+
model->n_weights = nweights;
225+
model->n_bias = nbias;
226+
227+
switch(activation_function){
228+
case 'r':
229+
model->output_activation_function = &relu;
230+
model->hidden_activation_function = &relu;
231+
break;
232+
case 'R':
233+
model->output_activation_function = &relu2;
234+
model->hidden_activation_function = &relu2;
235+
break;
236+
case 's':
237+
model->output_activation_function = &sigmoid;
238+
model->hidden_activation_function = &sigmoid;
239+
break;
240+
case 't':
241+
model->output_activation_function = &tanhf;
242+
model->hidden_activation_function = &tanhf;
243+
break;
244+
default:
245+
model->output_activation_function = &relu;
246+
model->hidden_activation_function = &relu;
247+
break;
248+
}
249+
}
250+
251+
void set_model_hyperparameters(ANN *model, float learning_rate, float bias_learning_rate, float momentum_factor){
252+
model->eta = learning_rate;
253+
model->beta = bias_learning_rate;
254+
model->alpha = momentum_factor;
255+
}
256+
257+
void set_learning_rate(ANN *model, float eta){
258+
model->eta = eta;
259+
}
260+
261+
void set_bias_learning_rate(ANN *model, float beta){
262+
model->beta = beta;
263+
}
264+
265+
void set_momentum_factor(ANN *model, float alpha){
266+
model->alpha = alpha;
267+
}
268+
269+
void set_output_actfunc(ANN *model, char func){
270+
switch(func){
271+
case 'r':
272+
model->output_activation_function = &relu;
273+
model->output_activation_derivative = &relu_derivative;
274+
break;
275+
case 'R':
276+
model->output_activation_function = &relu2;
277+
model->output_activation_derivative = &relu2_derivative;
278+
break;
279+
case 's':
280+
model->output_activation_function = &sigmoid;
281+
model->output_activation_derivative = &sigmoid_derivative;
282+
break;
283+
case 't':
284+
model->output_activation_function = &tanhf;
285+
model->output_activation_derivative = &tanhf_derivative;
286+
break;
287+
default:
288+
model->output_activation_function = &relu;
289+
model->output_activation_derivative = &relu_derivative;
290+
break;
291+
}
292+
}
293+
294+
void set_hidden_actfunc(ANN *model, char func){
295+
switch(func){
296+
case 'r':
297+
model->hidden_activation_function = &relu;
298+
model->hidden_activation_derivative = &relu_derivative;
299+
break;
300+
case 'R':
301+
model->hidden_activation_function = &relu2;
302+
model->hidden_activation_derivative = &relu2_derivative;
303+
break;
304+
case 's':
305+
model->hidden_activation_function = &sigmoid;
306+
model->hidden_activation_derivative = &sigmoid_derivative;
307+
break;
308+
case 't':
309+
model->hidden_activation_function = &tanhf;
310+
model->hidden_activation_derivative = &tanhf_derivative;
311+
break;
312+
default:
313+
model->hidden_activation_function = &relu;
314+
model->hidden_activation_derivative = &relu_derivative;
315+
break;
316+
}
317+
}

Source/embeddedML.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,16 @@ void run_ann(ANN *net, float *input);
5050
void init_ann(ANN *net);
5151
void init_pretrained_ann(ANN *net);
5252

53+
void set_model_memory(ANN *model, float *weights, float *dedw, float *bias, float *output);
54+
void set_model_parameters(ANN *model, unsigned int *topology, unsigned int nlayers, char activation_function);
55+
void set_model_hyperparameters(ANN *model, float learning_rate, float bias_learning_rate, float momentum_factor);
56+
57+
void set_learning_rate(ANN *model, float eta);
58+
void set_bias_learning_rate(ANN *model, float beta);
59+
void set_momentum_factor(ANN *model, float alpha);
60+
void set_output_actfunc(ANN *model, char func);
61+
void set_hidden_actfunc(ANN *model, char func);
62+
5363
//-----Utility-----
5464
void fill_zeros(float *v, unsigned int size);
5565
void fill_number(float *v, unsigned int size, float number);

0 commit comments

Comments
 (0)