Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
generated/
anchor-go

7 changes: 7 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
gen-metaplex:
go build && rm -rf ./generated/metaplex_token_metadata && \
./anchor-go -program-id metaqbxxUerdq28cj1RbAWkYQm3ybzjb6a8bt518x1s \
-type-id uint8 \
-src idl/metaplex/token-metadata-1.14.0.json \
-dst ./generated/metaplex_token_metadata \
-pkg metaplex-token-metadata
34 changes: 34 additions & 0 deletions gen_testing.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package main

import (
"fmt"

. "github.com/dave/jennifer/jen"
. "github.com/gagliardetto/utilz"
)
Expand Down Expand Up @@ -185,6 +187,9 @@ func genTestWithComplexEnum(tFunGroup *Group, insExportedName string, instructio
variantBlock.Id("params").Dot("AccountMetaSlice").Op("=").Nil()
variantBlock.Id("tmp").Op(":=").New(Id(formatComplexEnumVariantTypeName(enumName, variant.Name)))
variantBlock.Id("fu").Dot("Fuzz").Call(Id("tmp"))
if initCode := genInitializeComplexEnumFields(idl, enumName, variant); !isEmpty(initCode) {
variantBlock.Add(initCode)
}
variantBlock.Id("params").Dot("Set" + exportedArgName).Call(Id("tmp"))

variantBlock.Id("buf").Op(":=").New(Qual("bytes", "Buffer"))
Expand All @@ -210,3 +215,32 @@ func genTestWithComplexEnum(tFunGroup *Group, insExportedName string, instructio
})
}
}

func isEmpty(code Code) bool {
return fmt.Sprintf("%#v", code) == ""
}

func genInitializeComplexEnumFields(idl IDL, enumName string, variant IdlEnumVariant) Code {
code := Empty()

// Get the variant type definition
if variant.Fields != nil && variant.Fields.IdlEnumFieldsNamed != nil {
for _, field := range *variant.Fields.IdlEnumFieldsNamed {
if isComplexEnum(field.Type) {
fieldName := ToCamel(field.Name)
enumTypeName := field.Type.GetIdlTypeDefined().Defined
interfaceType := idl.Types.GetByName(enumTypeName)

if len(interfaceType.Type.Variants) > 0 {
// Initialize to first variant ("None" variant)
firstVariantName := formatComplexEnumVariantTypeName(enumTypeName, interfaceType.Type.Variants[0].Name)
code.If(Id("tmp").Dot(fieldName).Op("==").Nil()).Block(
Id("tmp").Dot(fieldName).Op("=").Op("&").Id(firstVariantName).Block(),
).Line()
}
}
}
}

return code
}
193 changes: 135 additions & 58 deletions generator.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package main

import (
"fmt"

. "github.com/dave/jennifer/jen"
"github.com/davecgh/go-spew/spew"
bin "github.com/gagliardetto/binary"
Expand Down Expand Up @@ -113,6 +115,11 @@ func genTypeName(idlTypeEnv IdlType) Code {
arr := idlTypeEnv.GetArray()
st.Index(Id(Itoa(arr.Num))).Add(genTypeName(arr.Thing))
}
case idlTypeEnv.IsHashMap():
{
hashMap := idlTypeEnv.GetHashMap()
st.Map(genTypeName(hashMap.Key)).Add(genTypeName(hashMap.Value))
}
default:
panic(spew.Sdump(idlTypeEnv))
}
Expand Down Expand Up @@ -295,18 +302,21 @@ func genTypeDef(idl *IDL, withDiscriminator bool, def IdlTypeDef) Code {
Id(interfaceMethodName).Call(),
).Line().Line()

// Declare the enum variants container (non-exported, used internally)
code.Type().Id(containerName).StructFunc(
func(structGroup *Group) {
structGroup.Id("Enum").Qual(PkgDfuseBinary, "BorshEnum").Tag(map[string]string{
"borsh_enum": "true",
})
isLargeEnum := len(def.Type.Variants) > 8
if !isLargeEnum {
// Declare the enum variants container (non-exported, used internally)
code.Type().Id(containerName).StructFunc(
func(structGroup *Group) {
structGroup.Id("Enum").Qual(PkgDfuseBinary, "BorshEnum").Tag(map[string]string{
"borsh_enum": "true",
})

for _, variant := range def.Type.Variants {
structGroup.Id(ToCamel(variant.Name)).Id(formatComplexEnumVariantTypeName(enumTypeName, variant.Name))
}
},
).Line().Line()
for _, variant := range def.Type.Variants {
structGroup.Id(ToCamel(variant.Name)).Id(formatComplexEnumVariantTypeName(enumTypeName, variant.Name))
}
},
).Line().Line()
}

for _, variant := range def.Type.Variants {
// Name of the variant type if the enum is a complex enum (i.e. enum variants are inline structs):
Expand All @@ -333,8 +343,22 @@ func genTypeDef(idl *IDL, withDiscriminator bool, def IdlTypeDef) Code {
return nil
}())
}
case variant.Fields.IdlEnumFieldsTuple != nil:
// Handle tuple variants - create numbered fields
for i, tupleType := range *variant.Fields.IdlEnumFieldsTuple {
fieldName := fmt.Sprintf("Field%d", i)
structGroup.Id(fieldName).Add(genTypeName(tupleType)).
Add(func() Code {
if tupleType.IsIdlTypeOption() {
return Tag(map[string]string{
"bin": "optional",
})
}
return nil
}())
}
default:
// TODO: handle tuples
// TODO: handle other field types if any
panic("not handled: " + Sdump(variant.Fields))
}
},
Expand Down Expand Up @@ -486,31 +510,66 @@ func genMarshalWithEncoder_struct(

if isComplexEnum(field.Type) {
enumTypeName := field.Type.GetIdlTypeDefined().Defined
interfaceType := idl.Types.GetByName(enumTypeName)
isLargeEnum := len(interfaceType.Type.Variants) > 8

body.BlockFunc(func(argBody *Group) {
argBody.List(Id("tmp")).Op(":=").Id(formatEnumContainerName(enumTypeName)).Block()
argBody.Switch(Id("realvalue").Op(":=").Id("obj").Dot(exportedArgName).Op(".").Parens(Type())).
BlockFunc(func(switchGroup *Group) {
// TODO: maybe it's from idl.Accounts ???
interfaceType := idl.Types.GetByName(enumTypeName)
for variantIndex, variant := range interfaceType.Type.Variants {
variantTypeNameStruct := formatComplexEnumVariantTypeName(enumTypeName, variant.Name)

switchGroup.Case(Op("*").Id(variantTypeNameStruct)).
if isLargeEnum {
// For large enums, encode discriminant and data directly without container
argBody.Switch(Id("realvalue").Op(":=").Id("obj").Dot(exportedArgName).Op(".").Parens(Type())).
BlockFunc(func(switchGroup *Group) {
for variantIndex, variant := range interfaceType.Type.Variants {
variantTypeNameStruct := formatComplexEnumVariantTypeName(enumTypeName, variant.Name)

switchGroup.Case(Op("*").Id(variantTypeNameStruct)).
BlockFunc(func(caseGroup *Group) {
// Write enum discriminant as uint8
caseGroup.Err().Op(":=").Id("encoder").Dot("WriteUint8").Call(Lit(uint8(variantIndex)))
caseGroup.If(Err().Op("!=").Nil()).Block(Return(Err()))
// Write variant data
caseGroup.Err().Op("=").Id("encoder").Dot("Encode").Call(Id("realvalue"))
caseGroup.If(Err().Op("!=").Nil()).Block(Return(Err()))
})
}
// Handle nil case by defaulting to first variant
switchGroup.Default().
BlockFunc(func(caseGroup *Group) {
caseGroup.Id("tmp").Dot("Enum").Op("=").Lit(variantIndex)
caseGroup.Id("tmp").Dot(ToCamel(variant.Name)).Op("=").Op("*").Id("realvalue")
// Write first variant discriminant (0)
caseGroup.Err().Op(":=").Id("encoder").Dot("WriteUint8").Call(Lit(uint8(0)))
caseGroup.If(Err().Op("!=").Nil()).Block(Return(Err()))
// Write empty first variant data
firstVariantName := formatComplexEnumVariantTypeName(enumTypeName, interfaceType.Type.Variants[0].Name)
caseGroup.List(Id("emptyVariant")).Op(":=").Op("&").Id(firstVariantName).Block()
caseGroup.Err().Op("=").Id("encoder").Dot("Encode").Call(Id("emptyVariant"))
caseGroup.If(Err().Op("!=").Nil()).Block(Return(Err()))
})
}
})

argBody.Err().Op(":=").Id("encoder").Dot("Encode").Call(Id("tmp"))

argBody.If(
Err().Op("!=").Nil(),
).Block(
Return(Err()),
)
})
} else {
// Use container struct for smaller enums
argBody.List(Id("tmp")).Op(":=").Id(formatEnumContainerName(enumTypeName)).Block()
argBody.Switch(Id("realvalue").Op(":=").Id("obj").Dot(exportedArgName).Op(".").Parens(Type())).
BlockFunc(func(switchGroup *Group) {
for variantIndex, variant := range interfaceType.Type.Variants {
variantTypeNameStruct := formatComplexEnumVariantTypeName(enumTypeName, variant.Name)

switchGroup.Case(Op("*").Id(variantTypeNameStruct)).
BlockFunc(func(caseGroup *Group) {
caseGroup.Id("tmp").Dot("Enum").Op("=").Lit(variantIndex)
caseGroup.Id("tmp").Dot(ToCamel(variant.Name)).Op("=").Op("*").Id("realvalue")
})
}
// Handle nil case by defaulting to first variant
switchGroup.Default().
BlockFunc(func(caseGroup *Group) {
caseGroup.Id("tmp").Dot("Enum").Op("=").Lit(0)
firstVariantName := ToCamel(interfaceType.Type.Variants[0].Name)
caseGroup.Id("tmp").Dot(firstVariantName).Op("=").Id(formatComplexEnumVariantTypeName(enumTypeName, interfaceType.Type.Variants[0].Name)).Block()
})
})

argBody.Err().Op(":=").Id("encoder").Dot("Encode").Call(Id("tmp"))
argBody.If(Err().Op("!=").Nil()).Block(Return(Err()))
}
})
} else {

Expand Down Expand Up @@ -619,35 +678,53 @@ func genUnmarshalWithDecoder_struct(
}

if isComplexEnum(field.Type) {
// TODO:
enumName := field.Type.GetIdlTypeDefined().Defined
body.BlockFunc(func(argBody *Group) {

argBody.List(Id("tmp")).Op(":=").New(Id(formatEnumContainerName(enumName)))

argBody.Err().Op(":=").Id("decoder").Dot("Decode").Call(Id("tmp"))

argBody.If(
Err().Op("!=").Nil(),
).Block(
Return(Err()),
)
interfaceType := idl.Types.GetByName(enumName)
isLargeEnum := len(interfaceType.Type.Variants) > 8

argBody.Switch(Id("tmp").Dot("Enum")).
BlockFunc(func(switchGroup *Group) {
interfaceType := idl.Types.GetByName(enumName)
for variantIndex, variant := range interfaceType.Type.Variants {
switchGroup.Case(Lit(variantIndex)).
body.BlockFunc(func(argBody *Group) {
if isLargeEnum {
// For large enums, read discriminant and data directly without container
argBody.List(Id("discriminant"), Err()).Op(":=").Id("decoder").Dot("ReadUint8").Call()
argBody.If(Err().Op("!=").Nil()).Block(Return(Err()))

argBody.Switch(Id("discriminant")).
BlockFunc(func(switchGroup *Group) {
for variantIndex, variant := range interfaceType.Type.Variants {
variantTypeNameStruct := formatComplexEnumVariantTypeName(enumName, variant.Name)
switchGroup.Case(Lit(uint8(variantIndex))).
BlockFunc(func(caseGroup *Group) {
caseGroup.List(Id("variant")).Op(":=").New(Id(variantTypeNameStruct))
caseGroup.Err().Op("=").Id("decoder").Dot("Decode").Call(Id("variant"))
caseGroup.If(Err().Op("!=").Nil()).Block(Return(Err()))
caseGroup.Id("obj").Dot(exportedArgName).Op("=").Id("variant")
})
}
switchGroup.Default().
BlockFunc(func(caseGroup *Group) {
caseGroup.Id("obj").Dot(exportedArgName).Op("=").Op("&").Id("tmp").Dot(ToCamel(variant.Name))
caseGroup.Return(Qual("fmt", "Errorf").Call(Lit("unknown enum index: %v"), Id("discriminant")))
})
}
switchGroup.Default().
BlockFunc(func(caseGroup *Group) {
caseGroup.Return(Qual("fmt", "Errorf").Call(Lit("unknown enum index: %v"), Id("tmp").Dot("Enum")))
})
})

})
} else {
// Use container struct for smaller enums
argBody.List(Id("tmp")).Op(":=").New(Id(formatEnumContainerName(enumName)))
argBody.Err().Op(":=").Id("decoder").Dot("Decode").Call(Id("tmp"))
argBody.If(Err().Op("!=").Nil()).Block(Return(Err()))

argBody.Switch(Id("tmp").Dot("Enum")).
BlockFunc(func(switchGroup *Group) {
for variantIndex, variant := range interfaceType.Type.Variants {
switchGroup.Case(Lit(variantIndex)).
BlockFunc(func(caseGroup *Group) {
caseGroup.Id("obj").Dot(exportedArgName).Op("=").Op("&").Id("tmp").Dot(ToCamel(variant.Name))
})
}
switchGroup.Default().
BlockFunc(func(caseGroup *Group) {
caseGroup.Return(Qual("fmt", "Errorf").Call(Lit("unknown enum index: %v"), Id("tmp").Dot("Enum")))
})
})
}
})
} else {

Expand Down
Loading