Skip to content

Commit 932c829

Browse files
committed
global premodulo for tally, allow export the boolean datamap
1 parent 5545bc8 commit 932c829

File tree

1 file changed

+45
-2
lines changed

1 file changed

+45
-2
lines changed

datasets/tally.go

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package datasets
22

33
import "sync"
4+
import "crypto/rand"
5+
import "encoding/binary"
46

57
// Tally is used to count votes on dataset features and return the majority votes
68
type Tally struct {
@@ -26,6 +28,9 @@ type Tally struct {
2628

2729
// improvementPossible reports whether an improvement is possible
2830
improvementPossible bool
31+
32+
// global premodulo and salt
33+
globalPremodulo, globalSalt uint32
2934
}
3035

3136
// Init initializes the tally dataset structure
@@ -34,13 +39,30 @@ func (t *Tally) Init() {
3439
t.correct = make(map[uint32]int64)
3540
t.improve = make(map[uint32]int64)
3641
}
42+
3743
// Free frees the memory occupied by tally dataset structure
3844
func (t *Tally) Free() {
3945
t.mapping = nil
4046
t.correct = nil
4147
t.improve = nil
4248
}
4349

50+
func (t *Tally) IsGlobalPremodulo() bool {
51+
return t.globalPremodulo != 0
52+
}
53+
func (t *Tally) SetGlobalPremodulo(mod uint32) {
54+
var b [4]byte
55+
rand.Read(b[:])
56+
t.globalSalt = binary.LittleEndian.Uint32(b[:])
57+
t.globalPremodulo = mod
58+
}
59+
func (t *Tally) GetGlobalSaltPremodulo() [2]uint32 {
60+
return [2]uint32{t.globalSalt, t.globalPremodulo}
61+
}
62+
func (t *Tally) GetGlobalPremodulo() uint32 {
63+
return t.globalPremodulo
64+
}
65+
4466
// SetFinalization sets isFinalization and enables the final stage of training
4567
func (t *Tally) SetFinalization(final bool) {
4668
t.isFinalization = final
@@ -52,6 +74,7 @@ func (t *Tally) GetImprovementPossible() bool {
5274
defer t.mut.Unlock()
5375
return t.improvementPossible
5476
}
77+
5578
// Len estimates the size of tally
5679
func (t *Tally) Len() (o int) {
5780
t.mut.Lock()
@@ -64,6 +87,7 @@ func (t *Tally) Len() (o int) {
6487
t.mut.Unlock()
6588
return
6689
}
90+
6791
// Improve votes for feature which improved the overall result
6892
func (t *Tally) AddToImprove(feature uint32, vote int8) {
6993
if vote == 0 {
@@ -102,7 +126,7 @@ func (t *Tally) AddToMapAll(feature uint16, output uint64, loss func(n uint32) u
102126
if t.mapping[feature] == nil {
103127
t.mapping[feature] = make(map[uint64]uint64)
104128
}
105-
t.mapping[feature][output] ++
129+
t.mapping[feature][output]++
106130
t.improvementPossible = true
107131
t.mut.Unlock()
108132
}
@@ -113,7 +137,7 @@ func (t *Tally) AddToMapping(feature uint16, output uint64) {
113137
if t.mapping[feature] == nil {
114138
t.mapping[feature] = make(map[uint64]uint64)
115139
}
116-
t.mapping[feature][output] ++
140+
t.mapping[feature][output]++
117141
t.improvementPossible = true
118142
t.mut.Unlock()
119143
}
@@ -155,3 +179,22 @@ func (t *Tally) Split() SplittedDataset {
155179
return sett.Split()
156180
}
157181
}
182+
183+
// Dataset gets binary Dataset from tally
184+
func (t *Tally) Dataset() Dataset {
185+
var sett Dataset
186+
sett.Init()
187+
// we initialize the set with pairs which improve first
188+
for value, rating := range t.improve {
189+
if rating != 0 {
190+
sett[value] = rating > 0
191+
}
192+
}
193+
// finally we overwrite the set with pairs which make it correct
194+
for value, rating := range t.correct {
195+
if rating != 0 {
196+
sett[value] = rating > 0
197+
}
198+
}
199+
return sett
200+
}

0 commit comments

Comments
 (0)