Skip to content

Commit 36f9bd0

Browse files
committed
add phonemizers (transformer) and their dataset
1 parent 39c1b0e commit 36f9bd0

File tree

4 files changed

+159
-169
lines changed

4 files changed

+159
-169
lines changed

cmd/train_phonemizer/main.go

Lines changed: 42 additions & 79 deletions
Original file line number | Diff line number | Diff line change
@@ -11,9 +11,10 @@ import "time"
1111

1212
import "github.com/neurlang/classifier/datasets/phonemizer"
1313
//import "github.com/neurlang/classifier/layer/majpool2d"
14-
//import "github.com/neurlang/classifier/layer/sum"
15-
//import "github.com/neurlang/classifier/layer/sochastic"
16-
import "github.com/neurlang/classifier/layer/parity"
14+
import "github.com/neurlang/classifier/layer/sum"
15+
import "github.com/neurlang/classifier/layer/sochastic"
16+
//import "github.com/neurlang/classifier/layer/parity"
17+
import "github.com/neurlang/classifier/layer/crossattention"
1718
import "github.com/neurlang/classifier/datasets"
1819
import "github.com/neurlang/classifier/hashtron"
1920
//import "github.com/neurlang/classifier/learning"
@@ -31,9 +32,13 @@ func error_abs(a, b uint32) uint32 {
3132
func main() {
3233
cleantsv := flag.String("cleantsv", "", "clean tsv dataset for the language")
3334
premodulo := flag.Int("premodulo", 0, "premodulo")
35+
minpremodulo := flag.Int("minpremodulo", 0, "minpremodulo")
36+
maxpremodulo := flag.Int("maxpremodulo", 0, "maxpremodulo")
37+
maxdepth := flag.Int("maxdepth", 0, "max training depth")
3438
part := flag.Int("part", 0, "train on n/part-th")
3539
dstmodel := flag.String("dstmodel", "", "model destination .json.lzw file")
3640
flag.Bool("pgo", false, "enable pgo")
41+
boosting := flag.Bool("padspace", false, "enable padspace")
3742
resume := flag.Bool("resume", false, "resume training")
3843
flag.Parse()
3944

@@ -43,58 +48,35 @@ func main() {
4348
println("clean tsv is mandatory")
4449
return
4550
}
51+
if maxdepth == nil || *maxdepth == 0 {
52+
println("max depth is mandatory")
53+
return
54+
}
4655

47-
data := phonemizer.Split(phonemizer.NewDataset(*cleantsv))
56+
data := phonemizer.Split(phonemizer.NewDataset(*cleantsv, boosting != nil && *boosting))
4857

4958
if len(data) == 0 {
5059
println("it looks like no data for this language, or language is unambiguous (no model needed)")
5160
return
5261
}
53-
/*
54-
55-
56-
57-
net.NewLayer(fanout1*fanout2*fanout3, 0)
58-
net.NewCombiner(sochastic.MustNew(fanout1*fanout2*fanout3, 32, 1))
59-
net.NewLayer(fanout1*fanout2, 0)
6062

61-
*/
62-
const fanout1 = 5
63-
var net feedforward.FeedforwardNetwork
64-
//net.NewLayer(fanout1, 0)
65-
//net.NewCombiner(sochastic.MustNew(fanout1, 32, 0))
66-
net.NewLayer(fanout1, 0)
67-
net.NewCombiner(parity.MustNew(fanout1))
68-
net.NewLayer(1, 0)
69-
/*
70-
net.NewCombiner(sochastic.MustNew(1, 32, 0))
71-
net.NewLayerPI(1, 0, 0)
72-
net.NewCombiner(sochastic.MustNew(1, 32, 0))
73-
net.NewLayerPI(1, 0, 0)
74-
net.NewCombiner(sochastic.MustNew(1, 1, 0))
75-
*/
76-
/*
77-
const fanout1 = 1
78-
const fanout2 = 5
63+
const fanout1 = 32
64+
const fanout2 = 4
7965
const fanout3 = 3
80-
const fanout4 = 5
81-
//const fanout5 = 1
82-
//const fanout6 = 4
83-
//const fanout7 = 1
84-
//const fanout8 = 5
85-
66+
8667
var net feedforward.FeedforwardNetwork
87-
//net.NewLayerP(fanout1*fanout2*fanout3*fanout4*fanout5*fanout6*fanout7*fanout8, 0, 1<<fanout8)
88-
//net.NewCombiner(majpool2d.MustNew2(fanout1*fanout2*fanout3*fanout4*fanout5*fanout6*fanout8, 1, fanout7, 1, fanout8, 1, 1, 0))
89-
//net.NewLayerP(fanout1*fanout2*fanout3*fanout4*fanout5*fanout6, 0, 1<<(fanout6*fanout6*2/3))
90-
//net.NewCombiner(majpool2d.MustNew2(fanout1*fanout2*fanout3*fanout4*fanout6, 1, fanout5, 1, fanout6, 1, 1, 0))
91-
net.NewLayerP(fanout1*fanout2*fanout3*fanout4, 0, 1<<13)
92-
net.NewCombiner(majpool2d.MustNew2(fanout1*fanout2*fanout4, 1, fanout3, 1, fanout4, 1, 1, 0))
9368
net.NewLayer(fanout1*fanout2, 0)
94-
//net.NewCombiner(full.MustNew(fanout2, 1, 1))
95-
net.NewCombiner(majpool2d.MustNew2(fanout2, 1, fanout1, 1, fanout2, 1, 1, 0))
69+
for i := 0; i < fanout3; i++ {
70+
net.NewCombiner(crossattention.MustNew(fanout1, fanout2))
71+
net.NewLayerPI(fanout1*fanout2, 0, 0)
72+
net.NewCombiner(sochastic.MustNew(fanout1*fanout2, 8*byte(i), uint32(i)))
73+
net.NewLayerPI(fanout1*fanout2, 0, 0)
74+
}
75+
net.NewCombiner(sochastic.MustNew(fanout1*fanout2, 32, fanout3))
76+
net.NewLayer(fanout1*fanout2, 0)
77+
net.NewCombiner(sum.MustNew([]uint{fanout1*fanout2}, 0))
9678
net.NewLayer(1, 0)
97-
*/
79+
9880

9981
trainWorst := func(worst int) func() {
10082
var tally = new(datasets.Tally)
@@ -103,7 +85,16 @@ func main() {
10385
if premodulo != nil && *premodulo > 0 {
10486
tally.SetGlobalPremodulo(uint32(*premodulo))
10587
}
106-
88+
if minpremodulo != nil && *minpremodulo > 0 && maxpremodulo != nil && *maxpremodulo > 0 {
89+
const span = 50 * 50
90+
value := (100 - improved_success_rate) * (100 - improved_success_rate)
91+
premodulo := value * ( *minpremodulo - *maxpremodulo ) / span + *maxpremodulo
92+
//println(improved_success_rate, premodulo)
93+
if premodulo < 2 {
94+
premodulo = 2
95+
}
96+
tally.SetGlobalPremodulo(uint32(premodulo))
97+
}
10798
var parts = 1
10899
if part != nil && *part > 1 {
109100
rand.Seed(time.Now().UnixNano())
@@ -113,7 +104,7 @@ func main() {
113104

114105
parallel.ForEach(len(data)/parts, 1000, func(jjj int) {
115106
{
116-
var io = data[jjj].V1()
107+
var io = data[jjj].V2(fanout1)
117108

118109
net.Tally4(io, worst, tally, nil)
119110
}
@@ -122,41 +113,13 @@ func main() {
122113
if !tally.GetImprovementPossible() {
123114
return nil
124115
}
125-
/*
126-
var h learning.HyperParameters
127-
h.Threads = runtime.NumCPU()
128-
h.Factor = 1 // affects the solution size
129-
130-
// shuffle before solving attempts
131-
h.Shuffle = true
132-
h.Seed = true
133-
134-
// restart when stuck
135-
h.DeadlineMs = 1000
136-
h.DeadlineRetry = 10
137116

138-
// affects how fast is the modulo reduced
139-
h.Subtractor = 1
140-
141-
// reduce Backtracking printing on the log
142-
h.Printer = 70
143-
144-
// save any solution to disk
145-
h.InitialLimit = 1000 + 4*tally.Len()
146-
h.EndWhenSolved = true
147-
148-
h.Name = fmt.Sprint(worst)
149-
//h.SetLogger("solutions11.txt")
150-
151-
//h.AvxLanes = 16
152-
//h.AvxSkip = 4
153-
*/
154117
fmt.Println("hashtron position:", worst, "(job size:", tally.Len(), ")")
155118
ptr := net.GetHashtron(worst)
156119
dset := tally.Dataset()
157120
q := quaternary.Make(dset)
158121
var pmod = [][2]uint32{}
159-
if premodulo != nil && *premodulo > 0 {
122+
if (premodulo != nil && *premodulo > 0) || (minpremodulo != nil && *minpremodulo > 0 && maxpremodulo != nil && *maxpremodulo > 0) {
160123
pmod = [][2]uint32{tally.GetGlobalSaltPremodulo()}
161124
}
162125
htron, err := hashtron.New(pmod, ptr.Bits(), []byte(q))
@@ -182,7 +145,7 @@ func main() {
182145
var percent, errsum atomic.Uint64
183146
parallel.ForEach(len(data)/parts, 1000, func(j int) {
184147
{
185-
var io = data[j].V1()
148+
var io = data[j].V2(fanout1)
186149

187150
var predicted = net.Infer2(io) & 1
188151

@@ -206,7 +169,7 @@ func main() {
206169

207170
if dstmodel != nil && len(*dstmodel) > 0 && improved_success_rate < success {
208171
if improved_success_rate > 0 {
209-
model := strings.ReplaceAll(*dstmodel, "weights1", "weights2")
172+
model := strings.ReplaceAll(*dstmodel, "weights1", "weights4")
210173
err := net.WriteZlibWeightsToFile(model)
211174
if err != nil {
212175
println(err.Error())
@@ -222,7 +185,7 @@ func main() {
222185
return success, h.Sum()
223186
}
224187
if resume != nil && *resume && dstmodel != nil {
225-
model := strings.ReplaceAll(*dstmodel, "weights1", "weights2")
188+
model := strings.ReplaceAll(*dstmodel, "weights1", "weights4")
226189

227190
err := net.ReadZlibWeightsFromFile(model)
228191
if err != nil {
@@ -244,7 +207,7 @@ func main() {
244207
if m.Exists(state, shuf[0], byte(success)) {
245208
continue
246209
}
247-
for worst := 0; worst < len(shuf); worst++ {
210+
for worst := 0; worst < len(shuf) && worst < *maxdepth; worst++ {
248211
println("training #", worst, "hastron of", len(shuf), "hashtrons total")
249212
if this_backoff := trainWorst(shuf[worst]); this_backoff != nil {
250213
infloop = -1

0 commit comments

Comments
 (0)