@@ -11,9 +11,10 @@ import "time"
1111
1212import "github.com/neurlang/classifier/datasets/phonemizer"
1313//import "github.com/neurlang/classifier/layer/majpool2d"
14- //import "github.com/neurlang/classifier/layer/sum"
15- //import "github.com/neurlang/classifier/layer/sochastic"
16- import "github.com/neurlang/classifier/layer/parity"
14+ import "github.com/neurlang/classifier/layer/sum"
15+ import "github.com/neurlang/classifier/layer/sochastic"
16+ //import "github.com/neurlang/classifier/layer/parity"
17+ import "github.com/neurlang/classifier/layer/crossattention"
1718import "github.com/neurlang/classifier/datasets"
1819import "github.com/neurlang/classifier/hashtron"
1920//import "github.com/neurlang/classifier/learning"
@@ -31,9 +32,13 @@ func error_abs(a, b uint32) uint32 {
3132func main () {
3233cleantsv := flag .String ("cleantsv" , "" , "clean tsv dataset for the language" )
3334premodulo := flag .Int ("premodulo" , 0 , "premodulo" )
35+ minpremodulo := flag .Int ("minpremodulo" , 0 , "minpremodulo" )
36+ maxpremodulo := flag .Int ("maxpremodulo" , 0 , "maxpremodulo" )
37+ maxdepth := flag .Int ("maxdepth" , 0 , "max training depth" )
3438part := flag .Int ("part" , 0 , "train on n/part-th" )
3539dstmodel := flag .String ("dstmodel" , "" , "model destination .json.lzw file" )
3640flag .Bool ("pgo" , false , "enable pgo" )
41+ boosting := flag .Bool ("padspace" , false , "enable padspace" )
3742resume := flag .Bool ("resume" , false , "resume training" )
3843flag .Parse ()
3944
@@ -43,58 +48,35 @@ func main() {
4348println ("clean tsv is mandatory" )
4449return
4550}
51+ if maxdepth == nil || * maxdepth == 0 {
52+ println ("max depth is mandatory" )
53+ return
54+ }
4655
47- data := phonemizer .Split (phonemizer .NewDataset (* cleantsv ))
56+ data := phonemizer .Split (phonemizer .NewDataset (* cleantsv , boosting != nil && * boosting ))
4857
4958if len (data ) == 0 {
5059println ("it looks like no data for this language, or language is unambiguous (no model needed)" )
5160return
5261}
53- /*
54-
55-
56-
57- net.NewLayer(fanout1*fanout2*fanout3, 0)
58- net.NewCombiner(sochastic.MustNew(fanout1*fanout2*fanout3, 32, 1))
59- net.NewLayer(fanout1*fanout2, 0)
6062
61- */
62- const fanout1 = 5
63- var net feedforward.FeedforwardNetwork
64- //net.NewLayer(fanout1, 0)
65- //net.NewCombiner(sochastic.MustNew(fanout1, 32, 0))
66- net .NewLayer (fanout1 , 0 )
67- net .NewCombiner (parity .MustNew (fanout1 ))
68- net .NewLayer (1 , 0 )
69- /*
70- net.NewCombiner(sochastic.MustNew(1, 32, 0))
71- net.NewLayerPI(1, 0, 0)
72- net.NewCombiner(sochastic.MustNew(1, 32, 0))
73- net.NewLayerPI(1, 0, 0)
74- net.NewCombiner(sochastic.MustNew(1, 1, 0))
75- */
76- /*
77- const fanout1 = 1
78- const fanout2 = 5
63+ const fanout1 = 32
64+ const fanout2 = 4
7965const fanout3 = 3
80- const fanout4 = 5
81- //const fanout5 = 1
82- //const fanout6 = 4
83- //const fanout7 = 1
84- //const fanout8 = 5
85-
66+
8667var net feedforward.FeedforwardNetwork
87- //net.NewLayerP(fanout1*fanout2*fanout3*fanout4*fanout5*fanout6*fanout7*fanout8, 0, 1<<fanout8)
88- //net.NewCombiner(majpool2d.MustNew2(fanout1*fanout2*fanout3*fanout4*fanout5*fanout6*fanout8, 1, fanout7, 1, fanout8, 1, 1, 0))
89- //net.NewLayerP(fanout1*fanout2*fanout3*fanout4*fanout5*fanout6, 0, 1<<(fanout6*fanout6*2/3))
90- //net.NewCombiner(majpool2d.MustNew2(fanout1*fanout2*fanout3*fanout4*fanout6, 1, fanout5, 1, fanout6, 1, 1, 0))
91- net.NewLayerP(fanout1*fanout2*fanout3*fanout4, 0, 1<<13)
92- net.NewCombiner(majpool2d.MustNew2(fanout1*fanout2*fanout4, 1, fanout3, 1, fanout4, 1, 1, 0))
9368net .NewLayer (fanout1 * fanout2 , 0 )
94- //net.NewCombiner(full.MustNew(fanout2, 1, 1))
95- net.NewCombiner(majpool2d.MustNew2(fanout2, 1, fanout1, 1, fanout2, 1, 1, 0))
69+ for i := 0 ; i < fanout3 ; i ++ {
70+ net .NewCombiner (crossattention .MustNew (fanout1 , fanout2 ))
71+ net .NewLayerPI (fanout1 * fanout2 , 0 , 0 )
72+ net .NewCombiner (sochastic .MustNew (fanout1 * fanout2 , 8 * byte (i ), uint32 (i )))
73+ net .NewLayerPI (fanout1 * fanout2 , 0 , 0 )
74+ }
75+ net .NewCombiner (sochastic .MustNew (fanout1 * fanout2 , 32 , fanout3 ))
76+ net .NewLayer (fanout1 * fanout2 , 0 )
77+ net .NewCombiner (sum .MustNew ([]uint {fanout1 * fanout2 }, 0 ))
9678net .NewLayer (1 , 0 )
97- */
79+
9880
9981trainWorst := func (worst int ) func () {
10082var tally = new (datasets.Tally )
@@ -103,7 +85,16 @@ func main() {
10385if premodulo != nil && * premodulo > 0 {
10486tally .SetGlobalPremodulo (uint32 (* premodulo ))
10587}
106-
88+ if minpremodulo != nil && * minpremodulo > 0 && maxpremodulo != nil && * maxpremodulo > 0 {
89+ const span = 50 * 50
90+ value := (100 - improved_success_rate ) * (100 - improved_success_rate )
91+ premodulo := value * ( * minpremodulo - * maxpremodulo ) / span + * maxpremodulo
92+ //println(improved_success_rate, premodulo)
93+ if premodulo < 2 {
94+ premodulo = 2
95+ }
96+ tally .SetGlobalPremodulo (uint32 (premodulo ))
97+ }
10798var parts = 1
10899if part != nil && * part > 1 {
109100rand .Seed (time .Now ().UnixNano ())
@@ -113,7 +104,7 @@ func main() {
113104
114105parallel .ForEach (len (data )/ parts , 1000 , func (jjj int ) {
115106{
116- var io = data [jjj ].V1 ( )
107+ var io = data [jjj ].V2 ( fanout1 )
117108
118109net .Tally4 (io , worst , tally , nil )
119110}
@@ -122,41 +113,13 @@ func main() {
122113if ! tally .GetImprovementPossible () {
123114return nil
124115}
125- /*
126- var h learning.HyperParameters
127- h.Threads = runtime.NumCPU()
128- h.Factor = 1 // affects the solution size
129-
130- // shuffle before solving attempts
131- h.Shuffle = true
132- h.Seed = true
133-
134- // restart when stuck
135- h.DeadlineMs = 1000
136- h.DeadlineRetry = 10
137116
138- // affects how fast is the modulo reduced
139- h.Subtractor = 1
140-
141- // reduce Backtracking printing on the log
142- h.Printer = 70
143-
144- // save any solution to disk
145- h.InitialLimit = 1000 + 4*tally.Len()
146- h.EndWhenSolved = true
147-
148- h.Name = fmt.Sprint(worst)
149- //h.SetLogger("solutions11.txt")
150-
151- //h.AvxLanes = 16
152- //h.AvxSkip = 4
153- */
154117fmt .Println ("hashtron position:" , worst , "(job size:" , tally .Len (), ")" )
155118ptr := net .GetHashtron (worst )
156119dset := tally .Dataset ()
157120q := quaternary .Make (dset )
158121var pmod = [][2 ]uint32 {}
159- if premodulo != nil && * premodulo > 0 {
122+ if ( premodulo != nil && * premodulo > 0 ) || ( minpremodulo != nil && * minpremodulo > 0 && maxpremodulo != nil && * maxpremodulo > 0 ) {
160123pmod = [][2 ]uint32 {tally .GetGlobalSaltPremodulo ()}
161124}
162125htron , err := hashtron .New (pmod , ptr .Bits (), []byte (q ))
@@ -182,7 +145,7 @@ func main() {
182145var percent , errsum atomic.Uint64
183146parallel .ForEach (len (data )/ parts , 1000 , func (j int ) {
184147{
185- var io = data [j ].V1 ( )
148+ var io = data [j ].V2 ( fanout1 )
186149
187150var predicted = net .Infer2 (io ) & 1
188151
@@ -206,7 +169,7 @@ func main() {
206169
207170if dstmodel != nil && len (* dstmodel ) > 0 && improved_success_rate < success {
208171if improved_success_rate > 0 {
209- model := strings .ReplaceAll (* dstmodel , "weights1" , "weights2 " )
172+ model := strings .ReplaceAll (* dstmodel , "weights1" , "weights4 " )
210173err := net .WriteZlibWeightsToFile (model )
211174if err != nil {
212175println (err .Error ())
@@ -222,7 +185,7 @@ func main() {
222185return success , h .Sum ()
223186}
224187if resume != nil && * resume && dstmodel != nil {
225- model := strings .ReplaceAll (* dstmodel , "weights1" , "weights2 " )
188+ model := strings .ReplaceAll (* dstmodel , "weights1" , "weights4 " )
226189
227190err := net .ReadZlibWeightsFromFile (model )
228191if err != nil {
@@ -244,7 +207,7 @@ func main() {
244207if m .Exists (state , shuf [0 ], byte (success )) {
245208continue
246209}
247- for worst := 0 ; worst < len (shuf ); worst ++ {
210+ for worst := 0 ; worst < len (shuf ) && worst < * maxdepth ; worst ++ {
248211println ("training #" , worst , "hastron of" , len (shuf ), "hashtrons total" )
249212if this_backoff := trainWorst (shuf [worst ]); this_backoff != nil {
250213infloop = - 1
0 commit comments