Skip to content

Commit d0c1d51

Browse files
committed
Query Key Value version of cross attention layer
1 parent 36f9bd0 commit d0c1d51

File tree

2 files changed

+46
-15
lines changed

2 files changed

+46
-15
lines changed

layer/crossattention/combiner.go

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,35 @@ func (f *CrossAttention) Put(n int, v bool) {
88
// Feature returns the n-th feature from the combiner. Next layer reads
99
// its inputs using this method for hashtron n in the next layer.
1010
func (f *CrossAttention) Feature(n int) (o uint32) {
11-
/*
12-
iov := (n % 3)
13-
if iov == 2 {
14-
return
15-
}
16-
17-
dim := f.dim
18-
beginhead := (n / dim) * dim
11+
if f.qkv {
12+
iov := n % 3
13+
dim := f.dim
14+
beginhead := (n / dim) * dim
1915

20-
for x := iov ^ 1; x < dim; x += 3 {
21-
others := f.vec[beginhead + x]
22-
value := f.vec[beginhead + x + (1 << iov)]
23-
me := f.vec[n]
24-
if others && me == value {
25-
o++
16+
if iov == 2 {
17+
// Handle the value position
18+
for x := 0; x < dim; x += 3 {
19+
query := f.vec[beginhead + x] // Get the query
20+
key := f.vec[beginhead + x + 1] // Get the key
21+
me := f.vec[n] // Current value
22+
if query && key && me {
23+
o++
24+
}
25+
}
26+
} else {
27+
// Handle query and key positions
28+
for x := 0; x < dim; x += 3 {
29+
others := f.vec[beginhead + x + iov] // Get the query or key based on iov
30+
value := f.vec[beginhead + x + 2] // Get the value
31+
me := f.vec[n] // Current query or key
32+
if others && me == value {
33+
o++
34+
}
35+
}
2636
}
37+
return o
2738
}
28-
*/
39+
2940
iov := n & 1
3041
dim := f.dim
3142
beginhead := (n / dim) * dim

layer/crossattention/layer.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,35 @@ package crossattention
44
import "github.com/neurlang/classifier/layer"
55

66
// CrossAttentionLayer holds the settings that Lay copies into the combiner
// it builds.
type CrossAttentionLayer struct {
	// qkv is set to true by New3 and handed to the combiner by Lay.
	qkv bool
	// dim and heads are handed to the combiner by Lay; Lay allocates
	// dim*heads boolean slots for it.
	dim, heads int
}
1011

1112
// CrossAttention is the combiner state that the Feature method reads.
type CrossAttention struct {
	// vec stores one boolean per position; Feature indexes it by position n
	// and by offsets from the start of the head that contains n.
	vec []bool
	// qkv, when true, makes Feature use its query/key/value branch, which
	// walks a head in strides of three.
	qkv bool
	// dim is the number of vec positions per head (Feature computes the head
	// start as (n / dim) * dim); heads is the head count.
	dim, heads int
}
1618

19+
// MustNew3 creates a new qkv full layer with size and bits
20+
func MustNew3(dim int, heads int) *CrossAttentionLayer {
21+
o, err := New3(dim, heads)
22+
if err != nil {
23+
panic(err.Error())
24+
}
25+
return o
26+
}
27+
28+
// New3 creates a new qkv full layer with size and bits
29+
func New3(dim int, heads int) (o *CrossAttentionLayer, err error) {
30+
o = new(CrossAttentionLayer)
31+
o.dim = dim
32+
o.heads = heads
33+
o.qkv = true
34+
return
35+
}
1736
// MustNew creates a new full layer with size and bits
1837
func MustNew(dim int, heads int) *CrossAttentionLayer {
1938
o, err := New(dim, heads)
@@ -37,5 +56,6 @@ func (i *CrossAttentionLayer) Lay() layer.Combiner {
3756
o.vec = make([]bool, i.dim * i.heads)
3857
o.dim = i.dim
3958
o.heads = i.heads
59+
o.qkv = i.qkv
4060
return o
4161
}

0 commit comments

Comments
 (0)