File tree Expand file tree Collapse file tree 2 files changed +46
-15
lines changed Expand file tree Collapse file tree 2 files changed +46
-15
lines changed Original file line number Diff line number Diff line change @@ -8,24 +8,35 @@ func (f *CrossAttention) Put(n int, v bool) {
88// Feature returns the n-th feature from the combiner. Next layer reads
99// its inputs using this method for hashtron n in the next layer.
1010func (f * CrossAttention ) Feature (n int ) (o uint32 ) {
11- /*
12- iov := (n % 3)
13- if iov == 2 {
14- return
15- }
16-
17- dim := f.dim
18- beginhead := (n / dim) * dim
11+ if f .qkv {
12+ iov := n % 3
13+ dim := f .dim
14+ beginhead := (n / dim ) * dim
1915
20- for x := iov ^ 1; x < dim; x += 3 {
21- others := f.vec[beginhead + x]
22- value := f.vec[beginhead + x + (1 << iov)]
23- me := f.vec[n]
24- if others && me == value {
25- o++
16+ if iov == 2 {
17+ // Handle the value position
18+ for x := 0 ; x < dim ; x += 3 {
19+ query := f .vec [beginhead + x ] // Get the query
20+ key := f .vec [beginhead + x + 1 ] // Get the key
21+ me := f .vec [n ] // Current value
22+ if query && key && me {
23+ o ++
24+ }
25+ }
26+ } else {
27+ // Handle query and key positions
28+ for x := 0 ; x < dim ; x += 3 {
29+ others := f .vec [beginhead + x + iov ] // Get the query or key based on iov
30+ value := f .vec [beginhead + x + 2 ] // Get the value
31+ me := f .vec [n ] // Current query or key
32+ if others && me == value {
33+ o ++
34+ }
35+ }
2636}
37+ return o
2738}
28- */
39+
2940iov := n & 1
3041dim := f .dim
3142beginhead := (n / dim ) * dim
Original file line number Diff line number Diff line change @@ -4,16 +4,35 @@ package crossattention
44import "github.com/neurlang/classifier/layer"
55
66type CrossAttentionLayer struct {
7+ qkv bool
78dim int
89heads int
910}
1011
1112type CrossAttention struct {
1213vec []bool
14+ qkv bool
1315dim int
1416heads int
1517}
1618
19+ // MustNew3 creates a new qkv full layer with size and bits
20+ func MustNew3 (dim int , heads int ) * CrossAttentionLayer {
21+ o , err := New3 (dim , heads )
22+ if err != nil {
23+ panic (err .Error ())
24+ }
25+ return o
26+ }
27+
28+ // New3 creates a new qkv full layer with size and bits
29+ func New3 (dim int , heads int ) (o * CrossAttentionLayer , err error ) {
30+ o = new (CrossAttentionLayer )
31+ o .dim = dim
32+ o .heads = heads
33+ o .qkv = true
34+ return
35+ }
1736// MustNew creates a new full layer with size and bits
1837func MustNew (dim int , heads int ) * CrossAttentionLayer {
1938o , err := New (dim , heads )
@@ -37,5 +56,6 @@ func (i *CrossAttentionLayer) Lay() layer.Combiner {
3756o .vec = make ([]bool , i .dim * i .heads )
3857o .dim = i .dim
3958o .heads = i .heads
59+ o .qkv = i .qkv
4060return o
4161}
You can’t perform that action at this time.
0 commit comments