|
| 1 | +package phonemizer_multi |
| 2 | + |
| 3 | +import "github.com/jbarham/primegen" |
| 4 | +import ( |
| 5 | +"github.com/neurlang/classifier/hash" |
| 6 | +"encoding/json" |
| 7 | +"sort" |
| 8 | +"strconv" |
| 9 | +) |
| 10 | +import ( |
| 11 | +"bufio" |
| 12 | +"fmt" |
| 13 | +"os" |
| 14 | +"strings" |
| 15 | +//"encoding/json" |
| 16 | +) |
| 17 | + |
| 18 | +var Primes []uint32 |
| 19 | + |
| 20 | +func init() { |
| 21 | +var p = primegen.New() |
| 22 | +for i := 0; i < 1024; i++ { |
| 23 | +Primes = append(Primes, uint32(p.Next())) |
| 24 | +} |
| 25 | +//fmt.Println(Primes) |
| 26 | +} |
| 27 | + |
| 28 | +// Sample is one sentence |
| 29 | +type Sample struct { |
| 30 | +Sentence []Token |
| 31 | +} |
| 32 | + |
// Token is one word of a sentence together with its candidate pronunciations.
type Token struct {
	// Homograph is the hash of the written word; it acts as the query.
	Homograph uint32
	// Solution is the hash of the correct IPA word; it acts as the value.
	Solution uint32
	// Choices lists the candidates. In each pair the first integer is the hash
	// of an IPA word (comparable with Solution) and the second is the tag key.
	Choices [][2]uint32
}
| 41 | + |
| 42 | +func (t *Token) Len() int { |
| 43 | +return len(t.Choices) |
| 44 | +} |
| 45 | + |
| 46 | +func (s *Sample) V1(dim, pos int) SampleSentence { |
| 47 | +return SampleSentence{ |
| 48 | +Sample: s, |
| 49 | +position: pos, |
| 50 | +dimension: dim, |
| 51 | +} |
| 52 | +} |
| 53 | + |
| 54 | +type SampleSentence struct { |
| 55 | +Sample *Sample |
| 56 | +position int |
| 57 | +dimension int |
| 58 | +} |
| 59 | + |
| 60 | +func (s *SampleSentence) Len() int { |
| 61 | +if len(s.Sample.Sentence) > s.position { |
| 62 | +return s.Sample.Sentence[s.position].Len() |
| 63 | +} |
| 64 | +return 0 |
| 65 | +} |
| 66 | + |
| 67 | +type SampleSentenceIO struct { |
| 68 | +SampleSentence *SampleSentence |
| 69 | +choice int |
| 70 | +} |
| 71 | + |
| 72 | +func (s *SampleSentence) IO(n int) (ret *SampleSentenceIO) { |
| 73 | +return &SampleSentenceIO{ |
| 74 | +SampleSentence: s, |
| 75 | +choice: n, |
| 76 | +} |
| 77 | +} |
| 78 | + |
| 79 | +// Feature: calculates query, key, value input for attention matrix |
| 80 | +// n - if dividible by 3, it's supposed to return the homograph |
| 81 | +// n - if equal to 1 divided by 3, it calculates the key token |
| 82 | +// n - if equal to 2 divided by 3, it calculates the value token |
| 83 | +func (s *SampleSentenceIO) Feature(n int) (ret uint32) { |
| 84 | +pos := (n / 3) % (s.SampleSentence.dimension / 3) |
| 85 | +if s.Parity() == 1 { |
| 86 | +ret = 1 << 31 |
| 87 | +} |
| 88 | +if n % 3 == 0 { |
| 89 | +for ; pos < len((s.SampleSentence.Sample.Sentence)); pos += (s.SampleSentence.dimension/3) { |
| 90 | +ret += uint32(s.SampleSentence.Sample.Sentence[pos].Homograph) + Primes[pos] |
| 91 | +} |
| 92 | +return |
| 93 | + |
| 94 | +} |
| 95 | +for ; pos < len((s.SampleSentence.Sample.Sentence)); pos += (s.SampleSentence.dimension/3) { |
| 96 | +if pos < s.SampleSentence.position { |
| 97 | +ret += uint32(s.SampleSentence.Sample.Sentence[pos].Solution) + Primes[pos] |
| 98 | +} else if pos == s.SampleSentence.position { |
| 99 | +choice := s.SampleSentence.Sample.Sentence[pos].Choices[s.choice] |
| 100 | +// Compare current choice with context |
| 101 | +if n%3 == 1 { |
| 102 | +ret += uint32(choice[1]) // Key |
| 103 | +} else if n%3 == 2 { |
| 104 | +ret += uint32(choice[0]) // Value |
| 105 | +} |
| 106 | +ret += Primes[pos] |
| 107 | +} else { |
| 108 | +ret += Primes[pos] |
| 109 | +} |
| 110 | +} |
| 111 | +return |
| 112 | +} |
| 113 | + |
| 114 | +func (s *SampleSentenceIO) Parity() (ret uint16) { |
| 115 | +return 0 |
| 116 | +} |
| 117 | +func (s *SampleSentenceIO) Output() (ret uint16) { |
| 118 | +if (s.SampleSentence.Sample.Sentence[s.SampleSentence.position].Choices[s.choice][0] == s.SampleSentence.Sample.Sentence[s.SampleSentence.position].Solution) { |
| 119 | +return 1 |
| 120 | +} |
| 121 | +return 0 |
| 122 | +} |
| 123 | + |
// loop reads the tab-separated file filename line by line and calls do once
// for every line that has two or three columns. The third argument is the
// empty string when the line has only two columns.
//
// Lines with any other column count are reported on standard output and
// skipped. Failures to open or read the file are also reported on standard
// output; loop itself returns nothing.
func loop(filename string, do func(string, string, string)) {
	f, err := os.Open(filename)
	if err != nil {
		fmt.Println("Error opening file:", err)
		return
	}
	defer f.Close()

	lines := bufio.NewScanner(f)
	for lines.Scan() {
		line := lines.Text()
		switch fields := strings.Split(line, "\t"); len(fields) {
		case 2:
			do(fields[0], fields[1], "")
		case 3:
			do(fields[0], fields[1], fields[2])
		default:
			fmt.Println("Line does not have exactly two or three columns:", line)
		}
	}

	// Scan stops silently on a read error, so it has to be checked afterwards.
	if err := lines.Err(); err != nil {
		fmt.Println("Error reading file:", err)
	}
}
| 163 | + |
| 164 | + |
| 165 | +func addTags(bag map[uint32]string, tags ...string) map[uint32]string { |
| 166 | +for _, v := range tags { |
| 167 | +bag[hash.StringHash(0, v)] = v |
| 168 | +} |
| 169 | +return bag |
| 170 | +} |
| 171 | + |
| 172 | +func parseTags(cell string) (ret map[uint32]string) { |
| 173 | +ret = make(map[uint32]string) |
| 174 | +if cell == "" { |
| 175 | +return |
| 176 | +} |
| 177 | +var tags []string |
| 178 | +err := json.Unmarshal([]byte(cell), &tags) |
| 179 | +if err != nil { |
| 180 | +fmt.Printf("Cell tag: %s, Error: %v\n", cell, err) |
| 181 | +} |
| 182 | +for _, v := range tags { |
| 183 | +ret[hash.StringHash(0, v)] = v |
| 184 | +} |
| 185 | +return |
| 186 | +} |
| 187 | + |
// serializeTags reduces a tag set to a compact key plus a canonical string.
//
// key is the XOR of all map keys, with 0 replaced by 1 so that the result is
// never zero. ret is the JSON array of the tag names in sorted order, which is
// "[]" for an empty or nil map.
func serializeTags(tags map[uint32]string) (key uint32, ret string) {
	names := make([]string, 0, len(tags))
	for hashed, name := range tags {
		key ^= hashed
		names = append(names, name)
	}
	// Map iteration order is random; sorting makes the JSON deterministic.
	sort.Strings(names)

	// The error is ignored as in the original; should nothing be encoded,
	// the "[]" default below stands.
	encoded, _ := json.Marshal(names)
	ret = "[]"
	if len(encoded) > 0 {
		ret = string(encoded)
	}

	if key == 0 {
		key = 1
	}
	return key, ret
}
| 206 | + |
| 207 | + |
| 208 | +func NewDataset(dir string) (ret []Sample) { |
| 209 | + |
| 210 | +var tags = make(map[uint32]string) |
| 211 | +var m = make(map[string]map[string]uint32) |
| 212 | + |
| 213 | +loop(dir + string(os.PathSeparator) + "dirty.tsv", func(src string, dst, tag string) { |
| 214 | +if _, ok := m[src]; !ok { |
| 215 | +m[src] = make(map[string]uint32) |
| 216 | +} |
| 217 | +var tagstr = "[]" |
| 218 | +if tag != "" { |
| 219 | +tagstr = tag |
| 220 | +} |
| 221 | +if _, ok := m[src][dst]; !ok { |
| 222 | +var tagkey, tagjson = serializeTags(addTags(parseTags(tagstr), "dict")) |
| 223 | +m[src][dst] = tagkey |
| 224 | +tags[tagkey] = tagjson |
| 225 | +} else { |
| 226 | +existingTags := parseTags(tags[m[src][dst]]) |
| 227 | +var existing []string |
| 228 | +for _, tag := range existingTags { |
| 229 | +existing = append(existing, tag) |
| 230 | +} |
| 231 | +var tagkey, tagjson = serializeTags(addTags(parseTags(tagstr), existing...)) |
| 232 | +m[src][dst] = tagkey |
| 233 | +tags[tagkey] = tagjson |
| 234 | +} |
| 235 | +}) |
| 236 | + |
| 237 | +loop(dir + string(os.PathSeparator) + "multi.tsv", func(src string, dst, _ string) { |
| 238 | +srcv := strings.Split(src, " ") |
| 239 | +dstv := strings.Split(dst, " ") |
| 240 | +if len(srcv) != len(dstv) { |
| 241 | +fmt.Println("Line does not have equal number of words:", src, dst) |
| 242 | +return |
| 243 | +} |
| 244 | +var s Sample |
| 245 | +for i := range srcv { |
| 246 | +var one = srcv[i] == "_" || dstv[i] == "_" |
| 247 | +if !one { |
| 248 | +println("LEXICON:", srcv[i], dstv[i]) |
| 249 | +} |
| 250 | +if len(m[srcv[i]]) == 0 { |
| 251 | +fmt.Println("ERROR: Word not in dict:", srcv[i], dstv[i]) |
| 252 | +t := Token{ |
| 253 | +Homograph: hash.StringHash(0, srcv[i]), |
| 254 | +Solution: 0, |
| 255 | +} |
| 256 | +s.Sentence = append(s.Sentence, t) |
| 257 | +continue |
| 258 | +} |
| 259 | +if len(m[srcv[i]]) == 1 != one { |
| 260 | +fmt.Println("ERROR: Word does not have one spoken form:", srcv[i], dstv[i]) |
| 261 | +for k, v := range m[srcv[i]] { |
| 262 | +println(k, v, tags[v]) |
| 263 | +} |
| 264 | +println() |
| 265 | +} |
| 266 | +var strkey [][2]string |
| 267 | +for k, v := range m[srcv[i]] { |
| 268 | +strkey = append(strkey, [2]string{k, fmt.Sprint(v)}) |
| 269 | +} |
| 270 | +sort.SliceStable(strkey, func(i, j int) bool { |
| 271 | +return strkey[i][0] < strkey[j][0] |
| 272 | +}) |
| 273 | +var array [][2]uint32 |
| 274 | +for _, v := range strkey { |
| 275 | +num, _ := strconv.Atoi(v[1]) |
| 276 | +array = append(array, [2]uint32{hash.StringHash(0, v[0]), uint32(num)}) |
| 277 | +} |
| 278 | +fmt.Println(array) |
| 279 | +for p := len(array)-1; p >= 0; p-- { |
| 280 | +for q := p-1; q >= 0; q-- { |
| 281 | +if array[p][1] == array[q][1] { |
| 282 | +array[q][1]++ |
| 283 | +} |
| 284 | +}} |
| 285 | +fmt.Println(array) |
| 286 | +sort.SliceStable(array, func(i, j int) bool { |
| 287 | +return array[i][0] < array[j][0] |
| 288 | +}) |
| 289 | +t := Token{ |
| 290 | +Homograph: hash.StringHash(0, srcv[i]), |
| 291 | +Solution: hash.StringHash(0, dstv[i]), |
| 292 | +Choices: array, |
| 293 | +} |
| 294 | +s.Sentence = append(s.Sentence, t) |
| 295 | +} |
| 296 | +fmt.Println(s) |
| 297 | +ret = append(ret, s) |
| 298 | +}) |
| 299 | +return |
| 300 | +} |
0 commit comments