diff --git a/arm64/instructions.go b/arm64/instructions.go index 2d4f7ee..fbb6bf7 100644 --- a/arm64/instructions.go +++ b/arm64/instructions.go @@ -26,8 +26,12 @@ func (arm64 *Arm64) EndDefine() { arm64.defineMode = false } -func (arm64 *Arm64) CBZ(label string, comment ...string) { - arm64.writeOp(comment, "BLE", label) +func (arm64 *Arm64) CBZ(r interface{}, label Label, comment ...string) { + arm64.writeOp(comment, "CBZ", r, string(label)) +} + +func (arm64 *Arm64) BLE(label Label, comment ...string) { + arm64.writeOp(comment, "BLE", string(label)) } func (arm64 *Arm64) LDP(address string, x, y interface{}, comment ...string) { @@ -39,6 +43,11 @@ func (arm64 *Arm64) LDPP(offset int, src, x, y interface{}, comment ...string) { arm64.writeOp(comment, "LDP.P", src, toTuple(x, y)) } +func (arm64 *Arm64) LDPW(offset int, src, x, y interface{}, comment ...string) { + src = fmt.Sprintf("%d(%s)", offset, Operand(src)) + arm64.writeOp(comment, "LDP.W", src, toTuple(x, y)) +} + func (arm64 *Arm64) STP(x, y interface{}, address string, comment ...string) { arm64.writeOp(comment, "STP", toTuple(x, y), address) //arm64.WriteLn(fmt.Sprintf("STP (R%d, R%d), %s", uint64(x), uint64(y), address)) @@ -60,6 +69,10 @@ func (arm64 *Arm64) ADC(op1, op2, dst interface{}, comment ...string) { arm64.writeOp(comment, "ADC", op1, op2, dst) } +func (arm64 *Arm64) SUB(op1, op2, dst interface{}, comment ...string) { + arm64.writeOp(comment, "SUB", op1, op2, dst) +} + func (arm64 *Arm64) SUBS(subtrahend, minuend, difference interface{}, comment ...string) { arm64.writeOp(comment, "SUBS", subtrahend, minuend, difference) } @@ -76,6 +89,195 @@ func (arm64 *Arm64) MOVD(src, dst interface{}, comment ...string) { arm64.writeOp(comment, "MOVD", src, dst) } +// JMP +func (arm64 *Arm64) JMP(label Label, comment ...string) { + arm64.writeOp(comment, "JMP", string(label)) +} + +// VLD1 +func (arm64 *Arm64) VLD1(offset int, src any, dst VectorRegister, comment ...string) { + src = fmt.Sprintf("%d(%s)", offset, Operand(src)) + arm64.writeOp(comment, "VLD1", src, dst.MemString()) +} + +// VADDV: Add all vector elements to produce a scalar result +func (arm64 *Arm64) VADDV(src, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "VADDV", src, dst) +} + +func (arm64 *Arm64) VLD1_P(offset int, src any, dst VectorRegister, comment ...string) { + src = fmt.Sprintf("%d(%s)", offset, Operand(src)) + arm64.writeOp(comment, "VLD1.P", src, dst.MemString()) +} + +func (arm64 *Arm64) VLD2_P(offset int, src any, dst1, dst2 VectorRegister, comment ...string) { + src = fmt.Sprintf("%d(%s)", offset, Operand(src)) + dst := VectorRegister(string(dst1) + ", " + string(dst2)) + arm64.writeOp(comment, "VLD2.P", src, dst.MemString()) +} + +// VSHL +func (arm64 *Arm64) VSHL(offset any, src, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "VSHL", offset, src, dst) +} + +// VMOV +func (arm64 *Arm64) VMOV(src, dst any, comment ...string) { + arm64.writeOp(comment, "VMOV", src, dst) +} + +func (arm64 *Arm64) VMOVI(value any, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "VMOVI", value, dst) +} + +func (arm64 *Arm64) VMOVS(value any, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "VMOVS", value, dst) +} + +func (arm64 *Arm64) VUSHLL(offset any, src, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "VUSHLL", offset, src, dst) +} + +func (arm64 *Arm64) VUSHLL2(offset any, src, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "VUSHLL2", offset, src, dst) +} + +// VDUP +func (arm64 *Arm64) VDUP(src any, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "VDUP", src, dst) +} + +// VUSHR +func (arm64 *Arm64) VUSHR(offset any, src, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "VUSHR", offset, src, dst) +} + +// SHRN +func (arm64 *Arm64) SHRN(immediate any, src, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "SHRN", immediate, src, dst) +} + +func (arm64 *Arm64) VUSRA(immediate any, src, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "VUSRA", immediate, src, dst) +} + +// VST1_P +func (arm64 *Arm64) VST1_P(src VectorRegister, dst interface{}, offset int, comment ...string) { + dst = fmt.Sprintf("%d(%s)", offset, Operand(dst)) + arm64.writeOp(comment, "VST1.P", src.MemString(), dst) +} + +// VST2_P +func (arm64 *Arm64) VST2_P(src1, src2 VectorRegister, dst interface{}, offset int, comment ...string) { + dst = fmt.Sprintf("%d(%s)", offset, Operand(dst)) + src := VectorRegister(string(src1) + ", " + string(src2)) + arm64.writeOp(comment, "VST2.P", src.MemString(), dst) +} + +// VST2 +func (arm64 *Arm64) VST2(src1, src2 VectorRegister, dst interface{}, comment ...string) { + arm64.writeOp(comment, "VST2", src1, src2, dst) +} + +func (arm64 *Arm64) VMOVQ_cst(c1, c2 any, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "VMOVQ", c1, c2, dst) +} + +// VEOR +func (arm64 *Arm64) VEOR(op1, op2, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "VEOR", op1, op2, dst) +} + +// "VREV16: Reverse byte order within 16-bit half-words.", +// "VREV32: Reverse byte order within 32-bit words.", +// "VREV64: Reverse byte order within 64-bit doublewords.", + +// VREV16 +func (arm64 *Arm64) VREV16(src, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "VREV16", src, dst) +} + +// VREV32 +func (arm64 *Arm64) VREV32(src, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "VREV32", src, dst) +} + +// VREV64 +func (arm64 *Arm64) VREV64(src, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "VREV64", src, dst) +} + +// VORR +func (arm64 *Arm64) VORR(op1, op2, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "VORR", op1, op2, dst) +} + +func (arm64 *Arm64) VEXT(n any, src, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "VEXT", n, src, dst) +} + +// VADD +func (arm64 *Arm64) VADD(op1, op2, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "VADD", op1, op2, dst) +} + +// VUADDW +func (arm64 *Arm64) VUADDW(op1, op2, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "VUADDW", op1, op2, dst) +} + +// UMULL +func (arm64 *Arm64) UMULL(op1, op2, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "UMULL", op1, op2, dst) +} + +// VPMULL +func (arm64 *Arm64) VPMULL(op1, op2, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "VPMULL", op1, op2, dst) +} + +// VPMULL2 +func (arm64 *Arm64) VPMULL2(op1, op2, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "VPMULL2", op1, op2, dst) +} + +// VAND +func (arm64 *Arm64) VAND(op1, op2, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "VAND", op1, op2, dst) +} + +// VSUB +func (arm64 *Arm64) VSUB(op1, op2, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "VSUB", op1, op2, dst) +} + +// VUMIN +func (arm64 *Arm64) VUMIN(op1, op2, dst VectorRegister, comment ...string) { + arm64.writeOp(comment, "VUMIN", op1, op2, dst) +} + +// MOVWU +func (arm64 *Arm64) MOVWU(src, dst interface{}, comment ...string) { + arm64.writeOp(comment, "MOVWU", src, dst) +} + +// MOVWUP_Load +func (arm64 *Arm64) MOVWUP_Load(offset int, src, dst interface{}, comment ...string) { + src = fmt.Sprintf("%d(%s)", offset, Operand(src)) + arm64.writeOp(comment, "MOVWU.P", src, dst) +} + +// MOVWUP_Store +func (arm64 *Arm64) MOVWUP_Store(src, dst interface{}, offset int, comment ...string) { + dst = fmt.Sprintf("%d(%s)", offset, Operand(dst)) + arm64.writeOp(comment, "MOVWU.P", src, dst) +} + +// MOVW +func (arm64 *Arm64) MOVW(src, dst interface{}, comment ...string) { + arm64.writeOp(comment, "MOVW", src, dst) +} + func (arm64 *Arm64) MOVDP(offset int, src, dst interface{}, comment ...string) { src = fmt.Sprintf("%d(%s)", offset, Operand(src)) arm64.writeOp(comment, "MOVD.P", src, dst) @@ -85,6 +287,10 @@ func (arm64 *Arm64) MUL(op1, op2, dst interface{}, comment ...string) { arm64.writeOp(comment, "MUL", op1, op2, dst) } +func (arm64 *Arm64) MULW(op1, op2, dst interface{}, comment ...string) { + arm64.writeOp(comment, "MULW", op1, op2, dst) +} + func (arm64 *Arm64) UMULH(op1, op2, dst interface{}, comment ...string) { arm64.writeOp(comment, "UMULH", op1, op2, dst) } @@ -101,6 +307,10 @@ func (arm64 *Arm64) CMP(a, b interface{}, comment ...string) { arm64.writeOp(comment, "CMP", a, b) } +func (arm64 *Arm64) BEQ(label string, comment ...string) { + arm64.writeOp(comment, "BEQ", label) +} + func toTuple(x, y interface{}) string { return fmt.Sprintf("(%s, %s)", Operand(x), Operand(y)) } @@ -143,7 +353,7 @@ func (arm64 *Arm64) FnHeader(funcName string, stackSize, argSize int, reserved . } arm64.WriteLn(fmt.Sprintf(header, funcName, stackSize, argSize)) - r := NewRegisters() + r := NewRegisters(arm64) for _, rr := range reserved { r.Remove(rr) } @@ -156,6 +366,8 @@ func Operand(i interface{}) string { return t case Register: return string(t) + case VectorRegister: + return string(t) case int: switch t { case 0: @@ -163,7 +375,7 @@ func Operand(i interface{}) string { case 1: return "$1" default: - return fmt.Sprintf("$%#016x", uint64(t)) + return fmt.Sprintf("$0x%x", t) } case uint64: switch t { @@ -172,7 +384,7 @@ func Operand(i interface{}) string { case 1: return "$1" default: - return fmt.Sprintf("$%#016x", t) + return fmt.Sprintf("$0x%x", t) } } panic("unsupported interface type") diff --git a/arm64/registers.go b/arm64/registers.go index a85561b..f256bf3 100644 --- a/arm64/registers.go +++ b/arm64/registers.go @@ -36,17 +36,111 @@ const ( R29 = Register("R29") ) +const ( + V0 = VectorRegister("V0") + V1 = VectorRegister("V1") + V2 = VectorRegister("V2") + V3 = VectorRegister("V3") + V4 = VectorRegister("V4") + V5 = VectorRegister("V5") + V6 = VectorRegister("V6") + V7 = VectorRegister("V7") + V8 = VectorRegister("V8") + V9 = VectorRegister("V9") + V10 = VectorRegister("V10") + V11 = VectorRegister("V11") + V12 = VectorRegister("V12") + V13 = VectorRegister("V13") + V14 = VectorRegister("V14") + V15 = VectorRegister("V15") + V16 = VectorRegister("V16") + V17 = VectorRegister("V17") + V18 = VectorRegister("V18") + V19 = VectorRegister("V19") + V20 = VectorRegister("V20") + V21 = VectorRegister("V21") + V22 = VectorRegister("V22") + V23 = VectorRegister("V23") + V24 = VectorRegister("V24") + V25 = VectorRegister("V25") + V26 = VectorRegister("V26") + V27 = VectorRegister("V27") + V28 = VectorRegister("V28") + V29 = VectorRegister("V29") + V30 = VectorRegister("V30") + V31 = VectorRegister("V31") +) + // type Label string type Register string +type VectorRegister string + +func (vr VectorRegister) MemString() string { + return "[" + string(vr) + "]" +} + +func (vr VectorRegister) SAt(i int) string { + return fmt.Sprintf("%s.S[%d]", string(vr), i) +} + +// DAt +func (vr VectorRegister) DAt(i int) string { + return fmt.Sprintf("%s.D[%d]", string(vr), i) +} + +func (vr VectorRegister) S4() VectorRegister { + return vr.withSuffix(".S4") +} + +// B8, B16 +func (vr VectorRegister) B8() VectorRegister { + return vr.withSuffix(".B8") +} + +func (vr VectorRegister) B16() VectorRegister { + return vr.withSuffix(".B16") +} + +// H8 +func (vr VectorRegister) H8() VectorRegister { + return vr.withSuffix(".H8") +} + +func (vr VectorRegister) Q1() VectorRegister { + return vr.withSuffix(".Q1") +} + +func (vr VectorRegister) S2() VectorRegister { + return vr.withSuffix(".S2") +} + +func (vr VectorRegister) D1() VectorRegister { + return vr.withSuffix(".D1") +} + +func (vr VectorRegister) D2() VectorRegister { + return vr.withSuffix(".D2") +} + +func (vr VectorRegister) withSuffix(suffix string) VectorRegister { + return VectorRegister(string(vr) + suffix) +} type Registers struct { - registers []Register + registers []Register + vRegisters []VectorRegister + vAliases map[string]VectorRegister + f *Arm64 } func (r *Register) At(wordOffset int) string { return fmt.Sprintf("%d(%s)", wordOffset*8, string(*r)) } +func (r *Register) At2(wordOffset int) string { + return fmt.Sprintf("%d(%s)", wordOffset*4, string(*r)) +} + func (r *Registers) Available() int { return len(r.registers) } @@ -57,6 +151,24 @@ func (r *Registers) Pop() Register { return toReturn } +func (r *Registers) PopV(alias ...string) VectorRegister { + toReturn := r.vRegisters[0] + r.vRegisters = r.vRegisters[1:] + + if len(alias) > 0 { + // check if alias is already used + if _, ok := r.vAliases[alias[0]]; ok { + panic("alias already used") + } + r.vAliases[alias[0]] = toReturn + // write a #define + r.f.WriteLn(fmt.Sprintf("#define %s %s", alias[0], string(toReturn))) + return VectorRegister(alias[0]) + } + + return toReturn +} + func (r *Registers) PopN(n int) []Register { toReturn := make([]Register, n) for i := 0; i < n; i++ { @@ -97,14 +209,81 @@ func (r *Registers) Push(rIn ...Register) { } -func NewRegisters() Registers { +func (r *Registers) PushV(rIn ...VectorRegister) { + // ensure register is in our original list, and no duplicate + for _, register := range rIn { + if _, ok := vRegisterSet[register]; !ok { + // check if it's an alias + realRegister, ok := r.vAliases[string(register)] + if !ok { + panic("warning: unknown register") + } + // remove the alias + delete(r.vAliases, string(register)) + // undef + r.f.WriteLn("#undef " + string(register)) + register = realRegister + } + found := false + for _, existing := range r.vRegisters { + if register == existing { + found = true + break + } + } + if found { + panic("duplicate register, already present.") + } + r.vRegisters = append(r.vRegisters, register) + } + +} + +func NewRegisters(arm64 *Arm64) Registers { r := Registers{ - registers: make([]Register, len(registers)), + registers: make([]Register, len(registers)), + vRegisters: make([]VectorRegister, len(vRegisters)), + vAliases: make(map[string]VectorRegister), + f: arm64, } copy(r.registers, registers) + copy(r.vRegisters, vRegisters) return r } +func (r *Registers) AssertCleanState() { + if len(r.vRegisters) != len(vRegisters) { + // find the ones that are missing for a clear error message + for _, vr := range vRegisters { + found := false + for _, vr2 := range r.vRegisters { + if vr == vr2 { + found = true + break + } + } + if !found { + panic(fmt.Sprintf("missing push vector register %s", vr)) + } + } + } + if len(r.registers) != len(registers) { + // find the ones that are missing for a clear error message + for _, vr := range registers { + found := false + for _, vr2 := range r.registers { + if vr == vr2 { + found = true + break + } + } + if !found { + panic(fmt.Sprintf("missing push register %s", vr)) + } + } + } +} + // NbRegisters contains nb default available registers, without BP const NbRegisters = 27 @@ -138,7 +317,45 @@ var registers = []Register{ R29, // risky. (reserved for FP) } -var registerSet map[Register]struct{} +var vRegisters = []VectorRegister{ + V0, + V1, + V2, + V3, + V4, + V5, + V6, + V7, + V8, + V9, + V10, + V11, + V12, + V13, + V14, + V15, + V16, + V17, + V18, + V19, + V20, + V21, + V22, + V23, + V24, + V25, + V26, + V27, + V28, + V29, + V30, + V31, +} + +var ( + registerSet map[Register]struct{} + vRegisterSet map[VectorRegister]struct{} +) func init() { registerSet = make(map[Register]struct{}, 0) @@ -148,6 +365,12 @@ func init() { if len(registers) != NbRegisters { panic("update nb available registers") } + + vRegisterSet = make(map[VectorRegister]struct{}, 0) + for _, register := range vRegisters { + vRegisterSet[register] = struct{}{} + } + } func (arm64 *Arm64) NewLabel(prefix ...string) Label {