Skip to content

Commit 4582232

Browse files
committed
Add all keras optimizers
Add custom class weighting. Add configurable losses and binary crossentropy loss. Remove batch size from model compile method. Move metrics to >= for positive. Add binary metrics. Fix training=True for evaluation bug.
1 parent 65f7d59 commit 4582232

File tree

24 files changed

+1250
-73
lines changed

24 files changed

+1250
-73
lines changed

data/single_file_dataset.go

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ type SingleFileDatasetConfig struct {
7474
RowFilter func(line []string) bool
7575
ConcurrentFileLimit int32
7676
MaxRowsForProcessorFit int
77+
ClassWeights map[int]float32
7778
}
7879

7980
func NewSingleFileDataset(
@@ -97,6 +98,10 @@ func NewSingleFileDataset(
9798
config.MaxRowsForProcessorFit = 1000000
9899
}
99100

101+
if config.ClassWeights == nil {
102+
config.ClassWeights = make(map[int]float32)
103+
}
104+
100105
var openFileCount int32
101106
var generatorOffset int32
102107

@@ -114,7 +119,7 @@ func NewSingleFileDataset(
114119
valPercent: config.ValPercent,
115120
testPercent: config.TestPercent,
116121
ClassCounts: make(map[int]int),
117-
ClassWeights: make(map[int]float32),
122+
ClassWeights: config.ClassWeights,
118123
filter: config.RowFilter,
119124
concurrentFileLimit: config.ConcurrentFileLimit,
120125
openFileCount: &openFileCount,
@@ -201,7 +206,9 @@ func (d *SingleFileDataset) readLineOffsets() error {
201206
d.lineOffsets = cache.LineOffsets
202207
d.Count = cache.Count
203208
d.ClassCounts = cache.ClassCounts
204-
d.ClassWeights = cache.ClassWeights
209+
if len(d.ClassWeights) == 0 {
210+
d.ClassWeights = cache.ClassWeights
211+
}
205212

206213
d.logger.InfoF("data", "Found %d rows. Got class counts: %#v Got class weights: %#v", d.Count, d.ClassCounts, d.ClassWeights)
207214

@@ -308,14 +315,16 @@ func (d *SingleFileDataset) readLineOffsets() error {
308315
swg.Wait()
309316
fmt.Println()
310317

311-
majorClassCount := 0
312-
for _, count := range d.ClassCounts {
313-
if count > majorClassCount {
314-
majorClassCount = count
318+
if len(d.ClassWeights) == 0 {
319+
majorClassCount := 0
320+
for _, count := range d.ClassCounts {
321+
if count > majorClassCount {
322+
majorClassCount = count
323+
}
324+
}
325+
for class, count := range d.ClassCounts {
326+
d.ClassWeights[class] = float32(majorClassCount) / float32(count)
315327
}
316-
}
317-
for class, count := range d.ClassCounts {
318-
d.ClassWeights[class] = float32(majorClassCount) / float32(count)
319328
}
320329

321330
cacheBytes, e := json.Marshal(fileStatsCache{

examples/class_weights/main.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/codingbeard/tfkg/layer"
1111
"github.com/codingbeard/tfkg/metric"
1212
"github.com/codingbeard/tfkg/model"
13+
"github.com/codingbeard/tfkg/optimizer"
1314
"github.com/codingbeard/tfkg/preprocessor"
1415
tf "github.com/galeone/tensorflow/tensorflow/go"
1516
"math/rand"
@@ -107,9 +108,7 @@ func main() {
107108

108109
// This part is pretty nasty under the hood. Effectively it will generate some python code for our model and execute it to save the model in a format we can load and train
109110
// A python binary must be available to use for this to work
110-
// The batchSize MUST match the batchSize in the call to Fit or Evaluate
111-
batchSize := 1000
112-
e = m.CompileAndLoad(batchSize, logsDir)
111+
e = m.CompileAndLoad(model.LossSparseCategoricalCrossentropy, optimizer.NewAdam(), logsDir)
113112
if e != nil {
114113
return
115114
}
@@ -130,7 +129,7 @@ func main() {
130129
model.FitConfig{
131130
Epochs: 10,
132131
Validation: true,
133-
BatchSize: batchSize,
132+
BatchSize: 1000,
134133
PreFetch: 10,
135134
Verbose: 1,
136135
Metrics: []metric.Metric{

examples/iris/main.go

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/codingbeard/tfkg/layer"
1111
"github.com/codingbeard/tfkg/metric"
1212
"github.com/codingbeard/tfkg/model"
13+
"github.com/codingbeard/tfkg/optimizer"
1314
"github.com/codingbeard/tfkg/preprocessor"
1415
tf "github.com/galeone/tensorflow/tensorflow/go"
1516
"os"
@@ -109,9 +110,7 @@ func main() {
109110

110111
// This part is pretty nasty under the hood. Effectively it will generate some python code for our model and execute it to save the model in a format we can load and train
111112
// A python binary must be available to use for this to work
112-
// The batchSize MUST match the batchSize in the call to Fit or Evaluate
113-
batchSize := 3
114-
e = m.CompileAndLoad(batchSize, saveDir)
113+
e = m.CompileAndLoad(model.LossSparseCategoricalCrossentropy, optimizer.NewAdam(), saveDir)
115114
if e != nil {
116115
return
117116
}
@@ -133,7 +132,7 @@ func main() {
133132
model.FitConfig{
134133
Epochs: 10,
135134
Validation: true,
136-
BatchSize: batchSize,
135+
BatchSize: 3,
137136
PreFetch: 10,
138137
Verbose: 1,
139138
Metrics: []metric.Metric{

examples/jobs/main.go

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/codingbeard/tfkg/layer"
1111
"github.com/codingbeard/tfkg/metric"
1212
"github.com/codingbeard/tfkg/model"
13+
"github.com/codingbeard/tfkg/optimizer"
1314
"github.com/codingbeard/tfkg/preprocessor"
1415
tf "github.com/galeone/tensorflow/tensorflow/go"
1516
"os"
@@ -271,12 +272,11 @@ func main() {
271272
layer.DenseWithActivation("swish"),
272273
)(mergedDense1)
273274

274-
// Get the number of classes from the dataset if we don't want to count them manually, but in this case it is only 2
275275
output := layer.NewDense(
276-
float64(dataset.NumCategoricalClasses()),
276+
1,
277277
layer.DenseWithDtype(layer.Float32),
278278
layer.DenseWithName("output"),
279-
layer.DenseWithActivation("softmax"),
279+
layer.DenseWithActivation("sigmoid"),
280280
)(mergedDense2)
281281

282282
// Define a keras style Functional model
@@ -289,9 +289,7 @@ func main() {
289289

290290
// This part is pretty nasty under the hood. Effectively it will generate some python code for our model and execute it to save the model in a format we can load and train
291291
// A python binary must be available to use for this to work
292-
// The batchSize MUST match the batch size in the call to Fit or Evaluate
293-
batchSize := 200
294-
e = m.CompileAndLoad(batchSize, saveDir)
292+
e = m.CompileAndLoad(model.LossBinaryCrossentropy, optimizer.NewAdam(), saveDir)
295293
if e != nil {
296294
return
297295
}
@@ -313,7 +311,7 @@ func main() {
313311
model.FitConfig{
314312
Epochs: 10,
315313
Validation: true,
316-
BatchSize: batchSize,
314+
BatchSize: 200,
317315
PreFetch: 10,
318316
Verbose: 1,
319317
Metrics: []metric.Metric{

examples/multiple_inputs/main.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/codingbeard/tfkg/layer"
1111
"github.com/codingbeard/tfkg/metric"
1212
"github.com/codingbeard/tfkg/model"
13+
"github.com/codingbeard/tfkg/optimizer"
1314
"github.com/codingbeard/tfkg/preprocessor"
1415
tf "github.com/galeone/tensorflow/tensorflow/go"
1516
"os"
@@ -111,7 +112,7 @@ func main() {
111112
// This part is pretty nasty under the hood. Effectively it will generate some python code for our model and execute it to save the model in a format we can load and train
112113
// A python binary must be available to use for this to work
113114
// The batchSize MUST match the batch size in the call to Fit or Evaluate
114-
e = m.CompileAndLoad(3, saveDir)
115+
e = m.CompileAndLoad(model.LossSparseCategoricalCrossentropy, optimizer.NewAdam(), saveDir)
115116
if e != nil {
116117
return
117118
}

generate/python/generate.go

Lines changed: 145 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ func main() {
3636
}
3737
for _, object := range objects {
3838
fmt.Println(object.Type, object.Name)
39-
if object.Type == "initializer" {
39+
if object.Type == "optimizer" {
40+
createOptimizer(object)
41+
} else if object.Type == "initializer" {
4042
createInitializer(object)
4143
} else if object.Type == "regularizer" {
4244
createRegularizer(object)
@@ -51,6 +53,112 @@ func main() {
5153
if e != nil {
5254
panic(e)
5355
}
56+
57+
_, e = exec.Command("go", "fmt", "github.com/codingbeard/tfkg/layer/constraint").Output()
58+
if e != nil {
59+
panic(e)
60+
}
61+
62+
_, e = exec.Command("go", "fmt", "github.com/codingbeard/tfkg/layer/initializer").Output()
63+
if e != nil {
64+
panic(e)
65+
}
66+
67+
_, e = exec.Command("go", "fmt", "github.com/codingbeard/tfkg/layer/regularizer").Output()
68+
if e != nil {
69+
panic(e)
70+
}
71+
72+
_, e = exec.Command("go", "fmt", "github.com/codingbeard/tfkg/optimizer").Output()
73+
if e != nil {
74+
panic(e)
75+
}
76+
}
77+
78+
func createOptimizer(object objectJson) {
79+
e := os.MkdirAll("../../optimizer", os.ModePerm)
80+
if e != nil {
81+
panic(e)
82+
}
83+
84+
var setters []string
85+
var objectProperties []string
86+
objectPropertyNames := make(map[string]bool)
87+
for _, param := range getRequiredParams(object) {
88+
objectPropertyNames[param[0]] = true
89+
objectProperties = append(objectProperties, fmt.Sprintf("%s %s", param[0], param[1]))
90+
}
91+
for _, param := range getOptionalParams(object) {
92+
objectPropertyNames[param[0]] = true
93+
objectProperties = append(objectProperties, fmt.Sprintf("%s %s", param[0], param[1]))
94+
setters = append(setters, getOptionString(object.Name, param[0], param[1]))
95+
}
96+
97+
subConfig := object.Config["config"].(map[string]interface{})
98+
for originalName, value := range subConfig {
99+
name := snakeCaseToCamelCase(originalName)
100+
if _, ok := objectPropertyNames[name]; ok {
101+
continue
102+
}
103+
objectProperties = append(objectProperties, fmt.Sprintf("%s %s", snakeCaseToCamelCase(name), getGolangTypeFromValue(object.Name, originalName, value)))
104+
}
105+
106+
var requiredParamSetters []string
107+
for _, paramName := range getRequiredParamNames(object) {
108+
requiredParamSetters = append(requiredParamSetters, fmt.Sprintf("%s: %s", paramName, paramName))
109+
}
110+
111+
var defaultParamSetters []string
112+
for _, param := range getOptionalParamDefaults(object) {
113+
defaultParamSetters = append(defaultParamSetters, fmt.Sprintf("%s: %s,", param[0], param[1]))
114+
}
115+
116+
lines := []string{
117+
"package optimizer",
118+
"",
119+
fmt.Sprintf("type %s struct {", object.Name),
120+
"\t" + strings.Join(objectProperties, "\n\t"),
121+
"}",
122+
"",
123+
fmt.Sprintf("func New%s(%s) *%s {", object.Name, getRequiredParamsString(object), object.Name),
124+
fmt.Sprintf(
125+
"\treturn &%s{\n\t\t%s%s\t\n\t}",
126+
object.Name,
127+
strings.Join(requiredParamSetters, "\n\t\t"),
128+
strings.Join(defaultParamSetters, "\n\t\t"),
129+
),
130+
"}",
131+
"",
132+
strings.Join(setters, "\n\n"),
133+
"",
134+
fmt.Sprintf("type jsonConfig%s struct {", object.Name),
135+
"\tClassName string `json:\"class_name\"`",
136+
"\tName string `json:\"name\"`",
137+
"\tConfig map[string]interface{} `json:\"config\"`",
138+
"}",
139+
fmt.Sprintf(
140+
"func (%s *%s) GetKerasLayerConfig() interface{} {",
141+
strings.ToLower(string(object.Name[0])),
142+
object.Name,
143+
),
144+
fmt.Sprintf("\tif %s == nil {", strings.ToLower(string(object.Name[0]))),
145+
"\t\treturn nil",
146+
"\t}",
147+
fmt.Sprintf("\treturn jsonConfig%s{", object.Name),
148+
fmt.Sprintf("\t\tClassName: \"%s\",", object.Config["class_name"]),
149+
fmt.Sprintf("\t\tConfig: %s,", getConfigValue(object)),
150+
"\t}",
151+
"}",
152+
}
153+
154+
e = ioutil.WriteFile(
155+
filepath.Join("../../optimizer", fmt.Sprintf("%s.go", object.Name)),
156+
[]byte(strings.Join(lines, "\n")),
157+
os.ModePerm,
158+
)
159+
if e != nil {
160+
panic(e)
161+
}
54162
}
55163

56164
func createInitializer(object objectJson) {
@@ -61,14 +169,26 @@ func createInitializer(object objectJson) {
61169

62170
var setters []string
63171
var objectProperties []string
172+
objectPropertyNames := make(map[string]bool)
64173
for _, param := range getRequiredParams(object) {
174+
objectPropertyNames[param[0]] = true
65175
objectProperties = append(objectProperties, fmt.Sprintf("%s %s", param[0], param[1]))
66176
}
67177
for _, param := range getOptionalParams(object) {
178+
objectPropertyNames[param[0]] = true
68179
objectProperties = append(objectProperties, fmt.Sprintf("%s %s", param[0], param[1]))
69180
setters = append(setters, getOptionString(object.Name, param[0], param[1]))
70181
}
71182

183+
subConfig := object.Config["config"].(map[string]interface{})
184+
for originalName, value := range subConfig {
185+
name := snakeCaseToCamelCase(originalName)
186+
if _, ok := objectPropertyNames[name]; ok {
187+
continue
188+
}
189+
objectProperties = append(objectProperties, fmt.Sprintf("%s %s", snakeCaseToCamelCase(name), getGolangTypeFromValue(object.Name, originalName, value)))
190+
}
191+
72192
var requiredParamSetters []string
73193
for _, paramName := range getRequiredParamNames(object) {
74194
requiredParamSetters = append(requiredParamSetters, fmt.Sprintf("%s: %s", paramName, paramName))
@@ -135,14 +255,26 @@ func createRegularizer(object objectJson) {
135255

136256
var setters []string
137257
var objectProperties []string
258+
objectPropertyNames := make(map[string]bool)
138259
for _, param := range getRequiredParams(object) {
260+
objectPropertyNames[param[0]] = true
139261
objectProperties = append(objectProperties, fmt.Sprintf("%s %s", param[0], param[1]))
140262
}
141263
for _, param := range getOptionalParams(object) {
264+
objectPropertyNames[param[0]] = true
142265
objectProperties = append(objectProperties, fmt.Sprintf("%s %s", param[0], param[1]))
143266
setters = append(setters, getOptionString(object.Name, param[0], param[1]))
144267
}
145268

269+
subConfig := object.Config["config"].(map[string]interface{})
270+
for originalName, value := range subConfig {
271+
name := snakeCaseToCamelCase(originalName)
272+
if _, ok := objectPropertyNames[name]; ok {
273+
continue
274+
}
275+
objectProperties = append(objectProperties, fmt.Sprintf("%s %s", snakeCaseToCamelCase(name), getGolangTypeFromValue(object.Name, originalName, value)))
276+
}
277+
146278
var requiredParamSetters []string
147279
for _, paramName := range getRequiredParamNames(object) {
148280
requiredParamSetters = append(requiredParamSetters, fmt.Sprintf("%s: %s", paramName, paramName))
@@ -209,14 +341,26 @@ func createConstraint(object objectJson) {
209341

210342
var setters []string
211343
var objectProperties []string
344+
objectPropertyNames := make(map[string]bool)
212345
for _, param := range getRequiredParams(object) {
346+
objectPropertyNames[param[0]] = true
213347
objectProperties = append(objectProperties, fmt.Sprintf("%s %s", param[0], param[1]))
214348
}
215349
for _, param := range getOptionalParams(object) {
350+
objectPropertyNames[param[0]] = true
216351
objectProperties = append(objectProperties, fmt.Sprintf("%s %s", param[0], param[1]))
217352
setters = append(setters, getOptionString(object.Name, param[0], param[1]))
218353
}
219354

355+
subConfig := object.Config["config"].(map[string]interface{})
356+
for originalName, value := range subConfig {
357+
name := snakeCaseToCamelCase(originalName)
358+
if _, ok := objectPropertyNames[name]; ok {
359+
continue
360+
}
361+
objectProperties = append(objectProperties, fmt.Sprintf("%s %s", snakeCaseToCamelCase(name), getGolangTypeFromValue(object.Name, originalName, value)))
362+
}
363+
220364
var requiredParamSetters []string
221365
for _, paramName := range getRequiredParamNames(object) {
222366
requiredParamSetters = append(requiredParamSetters, fmt.Sprintf("%s: %s", paramName, paramName))

generate/python/generate_keras_objects.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,14 @@
112112
{"type": "layer", "class": k.layers.experimental.EinsumDense, "args": []},
113113
{"type": "layer", "class": k.layers.experimental.RandomFourierFeatures, "args": []},
114114
{"type": "layer", "class": k.layers.experimental.SyncBatchNormalization, "args": []},
115+
{"type": "optimizer", "class": k.optimizers.Adagrad, "args": []},
116+
{"type": "optimizer", "class": k.optimizers.Adadelta, "args": []},
117+
{"type": "optimizer", "class": k.optimizers.Adam, "args": []},
118+
{"type": "optimizer", "class": k.optimizers.Adamax, "args": []},
119+
{"type": "optimizer", "class": k.optimizers.Ftrl, "args": []},
120+
{"type": "optimizer", "class": k.optimizers.Nadam, "args": []},
121+
{"type": "optimizer", "class": k.optimizers.RMSprop, "args": []},
122+
{"type": "optimizer", "class": k.optimizers.SGD, "args": []},
115123
]
116124
defaults = {
117125
"activation": "linear",

0 commit comments

Comments (0)