Skip to content

Commit f5667aa

Browse files
committed
Add datasets anytally, pretally, add trainer
1 parent 0fbfac8 commit f5667aa

File tree

6 files changed

+457
-0
lines changed

6 files changed

+457
-0
lines changed

datasets/anytally.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package datasets
2+
3+
type AnyTally interface {
4+
5+
// Erase
6+
Init()
7+
8+
// Global Premodulo
9+
IsGlobalPremodulo() bool
10+
SetGlobalPremodulo(mod uint32)
11+
GetGlobalSaltPremodulo() [2]uint32
12+
GetGlobalPremodulo() uint32
13+
14+
// Pre tallying
15+
GetCellDecision(position int, feature uint32) (bool, bool)
16+
SetCellDecision(position int, feature uint32, output bool)
17+
18+
// Tallying
19+
AddToCorrect(feature uint32, vote int8, improvement bool)
20+
AddToImprove(feature uint32, vote int8)
21+
AddToMapping(feature uint16, output uint64)
22+
23+
// Get Dataset at
24+
DatasetAt(n int) Dataset
25+
GetImprovementPossible() bool
26+
27+
// Len
28+
Len() (ret int)
29+
}
30+
31+
type TallyType byte
32+
const PreTallyType TallyType = 2
33+
const FinTallyType TallyType = 1
34+
35+
func NewAnyTally(typ TallyType) AnyTally {
36+
switch typ {
37+
case PreTallyType:
38+
t := &PreTally{}
39+
t.Init()
40+
return t
41+
case FinTallyType:
42+
t := &Tally{}
43+
t.Init()
44+
t.SetFinalization(true)
45+
return t
46+
default:
47+
return nil
48+
}
49+
}

datasets/pretally.go

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
package datasets
2+
3+
import "sync"
4+
import "crypto/rand"
5+
import "encoding/binary"
6+
7+
// PreTally stores distilled decisions for multiple cells with thread-safe access
8+
type PreTally struct {
9+
maps []map[uint32]bool
10+
mutex sync.RWMutex
11+
12+
// global premodulo and salt
13+
globalPremodulo, globalSalt uint32
14+
15+
}
16+
17+
18+
// Init resets pretally to be empty
19+
func (d *PreTally) Init() {
20+
d.mutex.Lock()
21+
defer d.mutex.Unlock()
22+
d.maps = nil
23+
}
24+
25+
26+
func (t *PreTally) IsGlobalPremodulo() bool {
27+
return t.globalPremodulo != 0
28+
}
29+
func (t *PreTally) SetGlobalPremodulo(mod uint32) {
30+
var b [4]byte
31+
rand.Read(b[:])
32+
t.globalSalt = binary.LittleEndian.Uint32(b[:])
33+
t.globalPremodulo = mod
34+
}
35+
func (t *PreTally) GetGlobalSaltPremodulo() [2]uint32 {
36+
return [2]uint32{t.globalSalt, t.globalPremodulo}
37+
}
38+
func (t *PreTally) GetGlobalPremodulo() uint32 {
39+
return t.globalPremodulo
40+
}
41+
42+
// GetCellDecision returns the distilled output for a specific cell and feature
43+
func (d *PreTally) GetCellDecision(position int, feature uint32) (bool, bool) {
44+
if position < 0 {
45+
return false, false
46+
}
47+
48+
d.mutex.RLock()
49+
defer d.mutex.RUnlock()
50+
51+
if position >= len(d.maps) {
52+
return false, false
53+
}
54+
55+
56+
val, exists := d.maps[position][feature]
57+
return val, exists
58+
}
59+
60+
// SetCellDecision stores a distilled decision for a specific cell and feature
61+
func (d *PreTally) SetCellDecision(position int, feature uint32, output bool) {
62+
if position < 0 {
63+
return
64+
}
65+
66+
d.mutex.Lock()
67+
defer d.mutex.Unlock()
68+
69+
for position >= len(d.maps) {
70+
d.maps = append(d.maps, make(map[uint32]bool))
71+
}
72+
73+
d.maps[position][feature] = output
74+
}
75+
76+
func (d *PreTally) Len() (ret int) {
77+
d.mutex.RLock()
78+
defer d.mutex.RUnlock()
79+
80+
for _, m := range d.maps {
81+
ret += len(m)
82+
}
83+
return
84+
}
85+
func (d *PreTally) Free() {
86+
d.mutex.Lock()
87+
defer d.mutex.Unlock()
88+
89+
d.maps = nil
90+
}
91+
func (d *PreTally) DatasetAt(position int) Dataset {
92+
if position < 0 {
93+
return nil
94+
}
95+
d.mutex.RLock()
96+
defer d.mutex.RUnlock()
97+
if position >= len(d.maps) {
98+
return nil
99+
}
100+
return d.maps[position]
101+
}
102+
103+
// GetImprovementPossible reads improvementPossible
104+
func (t *PreTally) GetImprovementPossible() bool {
105+
t.mutex.RLock()
106+
defer t.mutex.RUnlock()
107+
for _, m := range t.maps {
108+
if len(m) > 0 {
109+
return true
110+
}
111+
}
112+
return false
113+
}
114+
// AddToCorrect votes for feature which caused the overall result to be correct
115+
func (t *PreTally) AddToCorrect(feature uint32, vote int8, improvement bool) {
116+
t.mutex.Lock()
117+
defer t.mutex.Unlock()
118+
119+
if len(t.maps) == 0 {
120+
t.maps = append(t.maps, make(map[uint32]bool))
121+
}
122+
t.maps[0][feature] = vote > 0
123+
}
124+
// AddToImprove votes for feature which caused the overall result to be correct
125+
func (t *PreTally) AddToImprove(feature uint32, vote int8) {
126+
t.mutex.Lock()
127+
defer t.mutex.Unlock()
128+
129+
if len(t.maps) == 0 {
130+
t.maps = append(t.maps, make(map[uint32]bool))
131+
}
132+
t.maps[0][feature] = vote > 0
133+
}
134+
// AddToMapping adds feature maps to this output votes to mapping
135+
func (t *PreTally) AddToMapping(feature uint16, output uint64) {
136+
// not supported
137+
}

trainer/evaluatefunc.go

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
package trainer
2+
3+
import "fmt"
4+
import "os"
5+
import "math"
6+
7+
import "github.com/neurlang/classifier/parallel"
8+
import "github.com/neurlang/classifier/net/feedforward"
9+
10+
type dummy struct{}
11+
12+
func (d dummy) MustPutUint16(n int, value uint16) {}
13+
func (d dummy) Sum() [32]byte {
14+
return [32]byte{}
15+
}
16+
17+
type EvaluateFuncHasher interface {
18+
MustPutUint16(n int, value uint16)
19+
Sum() [32]byte
20+
}
21+
22+
// sampleSize calculates the statistically sufficient sample size
23+
// for a given dataset size N and significance level (0–100).
24+
func sampleSize(N int, significance byte) int {
25+
26+
// Convert significance level to Z-score
27+
z := zScoreFromAlpha(100 - significance)
28+
29+
// Assume worst-case proportion p = 0.5 for max variability
30+
p := 0.5
31+
e := float64(100 - significance) // Margin of error = 5%
32+
33+
numerator := math.Pow(z, 2) * p * (1 - p)
34+
denominator := math.Pow(e, 2)
35+
36+
// Initial sample size without population correction
37+
ss := numerator / denominator
38+
39+
// Apply finite population correction
40+
correctedSS := ss * float64(N) / (float64(N) - 1 + ss)
41+
42+
if int(correctedSS) > N {
43+
return N
44+
}
45+
46+
return int(correctedSS)
47+
}
48+
49+
// zScoreFromAlpha returns the Z-score for a given alpha level
50+
// Common: 90% => 1.645, 95% => 1.96, 99% => 2.576
51+
func zScoreFromAlpha(alpha byte) float64 {
52+
switch {
53+
case alpha <= 1:
54+
return 2.576 // 99% confidence
55+
case alpha <= 5:
56+
return 1.96 // 95% confidence
57+
case alpha <= 10:
58+
return 1.645 // 90% confidence
59+
default:
60+
return 1.96 // default fallback
61+
}
62+
}
63+
64+
func NewEvaluateFunc(net feedforward.FeedforwardNetwork, length int, significance byte, succ *int, dstmodel *string,
65+
testFunc func(portion int, h EvaluateFuncHasher) int) func() (int, [32]byte) {
66+
67+
return func() (int, [32]byte) {
68+
var h dummy
69+
var ha EvaluateFuncHasher = h
70+
var success int
71+
if length != 0 {
72+
length = sampleSize(length, significance)
73+
hsh := parallel.NewUint16Hasher(length)
74+
ha = hsh
75+
success = testFunc(length, hsh)
76+
} else {
77+
success = testFunc(0, h)
78+
}
79+
80+
if dstmodel == nil || *dstmodel == "" {
81+
err := net.WriteZlibWeightsToFile("output." + fmt.Sprint(success) + ".json.t.lzw")
82+
if err != nil {
83+
println(err.Error())
84+
}
85+
}
86+
87+
if dstmodel != nil && len(*dstmodel) > 0 && ((succ != nil && (*succ < success || success == 99)) || succ == nil) {
88+
if succ != nil && *succ > 0 {
89+
err := net.WriteZlibWeightsToFile(*dstmodel)
90+
if err != nil {
91+
println(err.Error())
92+
}
93+
}
94+
if succ != nil {
95+
*succ = success
96+
}
97+
} else if dstmodel != nil && len(*dstmodel) > 0 {
98+
if succ != nil {
99+
*succ = success
100+
}
101+
}
102+
103+
if success >= 100 {
104+
println("Max accuracy or wrong data. Exiting")
105+
os.Exit(0)
106+
}
107+
return success, ha.Sum()
108+
}
109+
}

trainer/loopfunc.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package trainer
2+
3+
import "os"
4+
import "fmt"
5+
import "math/rand"
6+
import "time"
7+
8+
import "github.com/neurlang/classifier/net/feedforward"
9+
import "github.com/neurlang/classifier/parallel"
10+
11+
func NewLoopFunc(net feedforward.FeedforwardNetwork, succ *int, treshold int, evaluate func() (int, [32]byte), trainWorst func([]int, int) (undo func())) func() {
12+
13+
var m = parallel.NewMoveSet()
14+
var success, state = evaluate()
15+
var default_backoff = func() {
16+
println("Infinite loop - algorithm stuck in local minimum. Exiting")
17+
os.Exit(0)
18+
}
19+
backoff := default_backoff
20+
var local_minimums = make(map[[32]byte]struct{})
21+
fmt.Printf("%x\n", state)
22+
for {
23+
for infloop := 0; infloop < net.Len(); infloop++ {
24+
var shuf []int
25+
if success < treshold {
26+
shuf = net.Sequence(false)
27+
rand.Seed(time.Now().UnixNano())
28+
rand.Shuffle(len(shuf), func(i, j int) { shuf[i], shuf[j] = shuf[j], shuf[i] })
29+
} else {
30+
shuf = net.Branch(false)
31+
}
32+
if m.Exists(state, shuf[0], byte(success)) {
33+
continue
34+
}
35+
for worst := 0; worst < len(shuf); worst++ {
36+
println("training #", worst, "hastron of", len(shuf), "hashtrons total")
37+
inSucc := success
38+
if succ != nil {
39+
inSucc = *succ
40+
}
41+
worsts := []int{shuf[worst]}
42+
if inSucc < treshold {
43+
if worst+1 < len(shuf) {
44+
worsts = append(worsts, shuf[worst+1])
45+
} else {
46+
break
47+
}
48+
}
49+
if this_backoff := trainWorst(worsts, inSucc); this_backoff != nil {
50+
infloop = -1
51+
this_success, this_state := evaluate()
52+
if _, bad := local_minimums[this_state]; bad {
53+
this_backoff()
54+
break
55+
} else {
56+
backoff, success, state = this_backoff, this_success, this_state
57+
}
58+
} else if worst == 0 {
59+
break
60+
}
61+
fmt.Printf("%x\n", state)
62+
m.Insert(state, shuf[worst], byte(success))
63+
if worst != len(shuf)-1 {
64+
if m.Exists(state, shuf[worst+1], byte(success)) {
65+
break
66+
}
67+
}
68+
}
69+
}
70+
local_minimums[state] = struct{}{}
71+
backoff()
72+
backoff = default_backoff
73+
success, state = evaluate()
74+
}
75+
}

trainer/resume.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package trainer
2+
3+
import "github.com/neurlang/classifier/net/feedforward"
4+
5+
func Resume(net feedforward.FeedforwardNetwork, resume *bool, dstmodel *string) {
6+
if resume != nil && *resume && dstmodel != nil {
7+
err := net.ReadZlibWeightsFromFile(*dstmodel)
8+
if err != nil {
9+
println(err.Error())
10+
}
11+
}
12+
}

0 commit comments

Comments
 (0)