1111from keras2c .check_model import check_model
1212from keras2c .make_test_suite import make_test_suite
1313import numpy as np
14+ import subprocess
1415import keras
1516import tensorflow as tf
1617tf .compat .v1 .disable_eager_execution ()
2324__email__ = "wconlin@princeton.edu"
2425
2526
26- def model2c (model , file , function_name , malloc = False , verbose = True ):
27+ def model2c (model , function_name , malloc = False , verbose = True ):
2728 """Generates C code for model
2829
2930 Writes main function definition to "function_name.c" and a public header
3031 with declarations to "function_name.h"
3132
3233 Args:
3334 model (keras Model): model to convert
34- file (open file instance): where to write main function
3535 function_name (str): name of C function
3636 malloc (bool): whether to allocate variables on the stack or heap
3737 verbose (bool): whether to print info to stdout
@@ -42,17 +42,17 @@ def model2c(model, file, function_name, malloc=False, verbose=True):
4242 """
4343
4444 model_inputs , model_outputs = get_model_io_names (model )
45- s = '#include <math.h> \n '
46- s += '#include <string.h> \n '
47- s += '#include "./include/k2c_include.h" \n '
48- s += '#include "./include/k2c_tensor_include.h" \n '
49- s += '\n \n '
50- file .write (s )
45+ includes = '#include <math.h> \n '
46+ includes += '#include <string.h> \n '
47+ includes += '#include "./include/k2c_include.h" \n '
48+ includes += '#include "./include/k2c_tensor_include.h" \n '
49+ includes += '\n \n '
5150
5251 if verbose :
5352 print ('Gathering Weights' )
5453 stack_vars , malloc_vars , static_vars = Weights2C (
5554 model , function_name , malloc ).write_weights (verbose )
55+ stateful = len (static_vars ) > 0
5656 layers = Layers2C (model , malloc ).write_layers (verbose )
5757
5858 function_signature = 'void ' + function_name + '('
@@ -64,107 +64,119 @@ def model2c(model, file, function_name, malloc=False, verbose=True):
6464 function_signature += ',' + ',' .join (['float* ' +
6565 key for key in malloc_vars .keys ()])
6666 function_signature += ')'
67- file .write (static_vars + '\n \n ' )
68- file .write (function_signature )
69- file .write (' { \n \n ' )
70- file .write (stack_vars )
71- file .write (layers )
72- file .write ('\n } \n \n ' )
73- stateful = len (static_vars ) > 0
7467
75- init_sig = write_function_initialize (file , function_name , malloc_vars )
76- term_sig = write_function_terminate (file , function_name , malloc_vars )
77- if stateful :
78- reset_sig = write_function_reset (file , function_name )
68+ init_sig , init_fun = gen_function_initialize (function_name , malloc_vars )
69+ term_sig , term_fun = gen_function_terminate (function_name , malloc_vars )
70+ reset_sig , reset_fun = gen_function_reset (function_name )
71+
72+ with open (function_name + '.c' , 'x+' ) as source :
73+ source .write (includes )
74+ source .write (static_vars + '\n \n ' )
75+ source .write (function_signature )
76+ source .write (' { \n \n ' )
77+ source .write (stack_vars )
78+ source .write (layers )
79+ source .write ('\n } \n \n ' )
80+ source .write (init_fun )
81+ source .write (term_fun )
82+ if stateful :
83+ source .write (reset_fun )
84+
7985 with open (function_name + '.h' , 'x+' ) as header :
8086 header .write ('#pragma once \n ' )
8187 header .write ('#include "./include/k2c_tensor_include.h" \n ' )
8288 header .write (function_signature + '; \n ' )
8389 header .write (init_sig + '; \n ' )
8490 header .write (term_sig + '; \n ' )
85-
8691 if stateful :
8792 header .write (reset_sig + '; \n ' )
93+ if not subprocess .run (['astyle' , '--version' ]).returncode :
94+ subprocess .run (['astyle' , '-n' , function_name + '.h' ])
95+ subprocess .run (['astyle' , '-n' , function_name + '.c' ])
8896
8997 return malloc_vars .keys (), stateful
9098
9199
92- def write_function_reset ( file , function_name ):
100+ def gen_function_reset ( function_name ):
93101 """Writes a reset function for stateful models
94102
95103 Reset function is used to clear internal state of the model
96104
97105 Args:
98- file (open file instance): file to write to
99106 function_name (str): name of main function
100107
101108 Returns:
102109 signature (str): delcaration of the reset function
110+ function (str): definition of the reset function
103111 """
104112
105- function_reset_signature = 'void ' + function_name + '_reset_states()'
106- file . write ( function_reset_signature )
107- s = ' { \n \n '
108- s += 'memset(&' + function_name + \
109- '_states,0,sizeof( ' + function_name + '_states)); \n '
110- s += "} \n \n "
111- file . write ( s )
112- return function_reset_signature
113+ reset_sig = 'void ' + function_name + '_reset_states()'
114+
115+ reset_fun = reset_sig
116+ reset_fun += ' { \n \n '
117+ reset_fun += 'memset(& ' + function_name + \
118+ '_states,0,sizeof(' + function_name + '_states)); \n '
119+ reset_fun += "} \n \n "
120+ return reset_sig , reset_fun
113121
114122
115- def write_function_initialize ( file , function_name , malloc_vars ):
123+ def gen_function_initialize ( function_name , malloc_vars ):
116124 """Writes an initialize function
117125
118126 Initialize function is used to load variables into memory and do other start up tasks
119127
120128 Args:
121- file (open file instance): file to write to
122129 function_name (str): name of main function
130+ malloc_vars (dict): variables to read in
123131
124132 Returns:
125133 signature (str): delcaration of the initialization function
134+ function (str): definition of the initialization function
126135 """
127136
128- function_init_signature = 'void ' + function_name + '_initialize('
129- function_init_signature += ',' .join (['float** ' +
130- key + ' \n ' for key in malloc_vars .keys ()])
131- function_init_signature += ')'
132- file .write (function_init_signature )
133- s = ' { \n \n '
137+ init_sig = 'void ' + function_name + '_initialize('
138+ init_sig += ',' .join (['float** ' +
139+ key + ' \n ' for key in malloc_vars .keys ()])
140+ init_sig += ')'
141+
142+ init_fun = init_sig
143+ init_fun += ' { \n \n '
134144 for key in malloc_vars .keys ():
135145 fname = function_name + key + ".csv"
136146 np .savetxt (fname , malloc_vars [key ], fmt = "%.8e" , delimiter = ',' )
137- s += '*' + key + " = k2c_read_array(\" " + \
147+ init_fun += '*' + key + " = k2c_read_array(\" " + \
138148 fname + "\" ," + str (malloc_vars [key ].size ) + "); \n "
139- s += "} \n \n "
140- file . write ( s )
141- return function_init_signature
149+ init_fun += "} \n \n "
150+
151+ return init_sig , init_fun
142152
143153
144- def write_function_terminate ( file , function_name , malloc_vars ):
154+ def gen_function_terminate ( function_name , malloc_vars ):
145155 """Writes a terminate function
146156
147157 Terminate function is used to deallocate memory after completion
148158
149159 Args:
150- file (open file instance): file to write to
151160 function_name (str): name of main function
161+ malloc_vars (dict): variables to deallocate
152162
153163 Returns:
154164 signature (str): delcaration of the terminate function
165+ function (str): definition of the terminate function
155166 """
156167
157- function_term_signature = 'void ' + function_name + '_terminate('
158- function_term_signature += ',' .join (['float* ' +
159- key for key in malloc_vars .keys ()])
160- function_term_signature += ')'
161- file .write (function_term_signature )
162- s = ' { \n \n '
168+ term_sig = 'void ' + function_name + '_terminate('
169+ term_sig += ',' .join (['float* ' +
170+ key for key in malloc_vars .keys ()])
171+ term_sig += ')'
172+
173+ term_fun = term_sig
174+ term_fun += ' { \n \n '
163175 for key in malloc_vars .keys ():
164- s += "free(" + key + "); \n "
165- s += "} \n \n "
166- file . write ( s )
167- return function_term_signature
176+ term_fun += "free(" + key + "); \n "
177+ term_fun += "} \n \n "
178+
179+ return term_sig , term_fun
168180
169181
170182def k2c (model , function_name , malloc = False , num_tests = 10 , verbose = True ):
@@ -200,10 +212,9 @@ def k2c(model, function_name, malloc=False, num_tests=10, verbose=True):
200212 if verbose :
201213 print ('All checks passed' )
202214
203- file = open (filename , "x+" )
204215 malloc_vars , stateful = model2c (
205- model , file , function_name , malloc , verbose )
206- file . close ()
216+ model , function_name , malloc , verbose )
217+
207218 s = 'Done \n '
208219 s += "C code is in '" + function_name + \
209220 ".c' with header file '" + function_name + ".h' \n "
0 commit comments