woodpecker-email/vendor/github.com/antonmedv/expr/compiler/compiler.go
2023-01-04 13:11:21 +01:00

674 lines
13 KiB
Go

package compiler
import (
"encoding/binary"
"fmt"
"math"
"reflect"
"github.com/antonmedv/expr/ast"
"github.com/antonmedv/expr/conf"
"github.com/antonmedv/expr/file"
"github.com/antonmedv/expr/parser"
. "github.com/antonmedv/expr/vm"
)
// Compile translates a parsed expression tree into a VM program.
//
// config is optional. When it is non-nil, config.MapEnv selects map-based
// identifier lookups (OpFetchMap) and config.Expect requests a trailing
// OpCast: operand 0 for reflect.Int64, operand 1 for reflect.Float64.
//
// Any panic raised while compiling (unknown node type or operator, constant
// pool overflow, ...) is recovered and returned as err. In that case program
// is left nil.
func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err error) {
	defer func() {
		if recovered := recover(); recovered != nil {
			err = fmt.Errorf("%v", recovered)
		}
	}()

	comp := &compiler{
		locations: make(map[int]file.Location),
		index:     make(map[interface{}]uint16),
	}
	if config != nil {
		comp.mapEnv, comp.cast = config.MapEnv, config.Expect
	}

	comp.compile(tree.Node)

	// Optionally coerce the final result to the expected numeric kind.
	if comp.cast == reflect.Int64 {
		comp.emit(OpCast, encode(0)...)
	} else if comp.cast == reflect.Float64 {
		comp.emit(OpCast, encode(1)...)
	}

	program = &Program{
		Source:    tree.Source,
		Locations: comp.locations,
		Constants: comp.constants,
		Bytecode:  comp.bytecode,
	}
	return program, nil
}
// compiler holds the mutable state used while translating an AST into bytecode.
type compiler struct {
	// locations maps the bytecode offset of every emitted opcode to the source
	// location of the AST node that was being compiled when it was emitted.
	locations map[int]file.Location
	// constants is the constant pool; opcodes refer to entries by their
	// little-endian uint16 index.
	constants []interface{}
	// bytecode is the instruction stream: each opcode byte is followed by its
	// operand bytes (if any).
	bytecode []byte
	// index maps hashable constant values to their slot in constants, so a
	// repeated value reuses one slot instead of growing the pool.
	index map[interface{}]uint16
	// mapEnv is copied from conf.Config.MapEnv; when set, identifiers compile
	// to OpFetchMap.
	mapEnv bool
	// cast is copied from conf.Config.Expect; reflect.Int64 and reflect.Float64
	// make Compile append an OpCast.
	cast reflect.Kind
	// nodes is the stack of AST nodes currently being compiled (innermost
	// last); emit uses the top entry for location tracking.
	nodes []ast.Node
}
// emit appends opcode op and its operand bytes b to the bytecode.
//
// The offset of the opcode is recorded in c.locations together with the
// location of the innermost node being compiled, or the zero Location when
// no node is active. The return value is the offset of the first operand
// byte, which is what patchJump expects for a placeholder operand.
func (c *compiler) emit(op byte, b ...byte) int {
	opAt := len(c.bytecode)

	var loc file.Location
	if depth := len(c.nodes); depth > 0 {
		loc = c.nodes[depth-1].Location()
	}
	c.locations[opAt] = loc

	c.bytecode = append(c.bytecode, op)
	c.bytecode = append(c.bytecode, b...)
	return opAt + 1
}
// emitPush emits OpPush whose operand is the constant-pool index of value.
// It returns the offset of that operand, as emit does.
func (c *compiler) emitPush(value interface{}) int {
	slot := c.makeConstant(value)
	return c.emit(OpPush, slot...)
}
// makeConstant stores i in the constant pool and returns its slot index,
// encoded as the two little-endian operand bytes of an opcode.
//
// Values whose dynamic type is comparable are deduplicated through c.index,
// so equal values share one slot. Values that cannot be used as map keys are
// always given a fresh slot. This covers slices, maps, funcs, and
// arrays/structs built from them. A nil value is handled explicitly, because
// reflect.TypeOf(nil) returns a nil Type.
//
// NOTE(review): reflect.Type.Comparable reports true for structs with
// interface-typed fields even when such a field holds an uncomparable value.
// Hashing such a value would still panic; confirm such constants never reach
// here.
//
// It panics when the pool would exceed math.MaxUint16 entries. Compile
// recovers that panic and reports it as an error.
func (c *compiler) makeConstant(i interface{}) []byte {
	hashable := true
	if t := reflect.TypeOf(i); t != nil {
		hashable = t.Comparable()
	}

	if hashable {
		if p, ok := c.index[i]; ok {
			return encode(p)
		}
	}

	c.constants = append(c.constants, i)
	if len(c.constants) > math.MaxUint16 {
		panic("exceeded constants max space limit")
	}

	p := uint16(len(c.constants) - 1)
	if hashable {
		c.index[i] = p
	}
	return encode(p)
}
// placeholder returns a fresh two-byte operand (0xFFFF) for a jump whose
// target is not known yet. patchJump overwrites it later.
func (c *compiler) placeholder() []byte {
	operand := make([]byte, 2)
	operand[0], operand[1] = 0xFF, 0xFF
	return operand
}
// patchJump back-fills a two-byte jump operand that was emitted with
// c.placeholder(). The argument is the offset of that operand, i.e. the
// value returned by emit.
//
// The operand is set to the forward distance from the end of the operand to
// the current end of the bytecode. The next instruction emitted therefore
// becomes the jump target.
//
// It panics if that distance does not fit the uint16 operand. Truncating it
// would silently produce a jump to the wrong instruction. Compile recovers
// the panic and returns it as an error.
func (c *compiler) patchJump(placeholder int) {
	offset := len(c.bytecode) - 2 - placeholder
	if offset < 0 || offset > math.MaxUint16 {
		panic(fmt.Sprintf("jump offset out of range (%d)", offset))
	}
	b := encode(uint16(offset))
	c.bytecode[placeholder] = b[0]
	c.bytecode[placeholder+1] = b[1]
}
// calcBackwardJump returns the encoded operand for an OpJumpBackward that
// jumps to bytecode offset `to`. The caller must emit that instruction
// immediately after this call.
//
// The distance is measured from the end of that instruction back to `to`:
// +1 for the opcode byte and +2 for the operand bytes.
//
// It panics when the distance does not fit the uint16 operand, rather than
// silently truncating it. Compile recovers the panic and returns it as an
// error.
func (c *compiler) calcBackwardJump(to int) []byte {
	offset := len(c.bytecode) + 1 + 2 - to
	if offset < 0 || offset > math.MaxUint16 {
		panic(fmt.Sprintf("jump offset out of range (%d)", offset))
	}
	return encode(uint16(offset))
}
// compile emits bytecode for node by dispatching on its concrete type.
//
// While node is being compiled it sits on top of c.nodes, so every
// instruction emitted for it (and not for a child) is attributed to its
// source location. Unknown node types panic; Compile recovers the panic.
func (c *compiler) compile(node ast.Node) {
	c.nodes = append(c.nodes, node)
	defer func() { c.nodes = c.nodes[:len(c.nodes)-1] }()

	switch typed := node.(type) {
	// Literals and identifiers.
	case *ast.NilNode:
		c.NilNode(typed)
	case *ast.IdentifierNode:
		c.IdentifierNode(typed)
	case *ast.IntegerNode:
		c.IntegerNode(typed)
	case *ast.FloatNode:
		c.FloatNode(typed)
	case *ast.BoolNode:
		c.BoolNode(typed)
	case *ast.StringNode:
		c.StringNode(typed)
	case *ast.ConstantNode:
		c.ConstantNode(typed)

	// Operators.
	case *ast.UnaryNode:
		c.UnaryNode(typed)
	case *ast.BinaryNode:
		c.BinaryNode(typed)
	case *ast.MatchesNode:
		c.MatchesNode(typed)

	// Member access and calls.
	case *ast.PropertyNode:
		c.PropertyNode(typed)
	case *ast.IndexNode:
		c.IndexNode(typed)
	case *ast.SliceNode:
		c.SliceNode(typed)
	case *ast.MethodNode:
		c.MethodNode(typed)
	case *ast.FunctionNode:
		c.FunctionNode(typed)
	case *ast.BuiltinNode:
		c.BuiltinNode(typed)
	case *ast.ClosureNode:
		c.ClosureNode(typed)
	case *ast.PointerNode:
		c.PointerNode(typed)

	// Compound expressions.
	case *ast.ConditionalNode:
		c.ConditionalNode(typed)
	case *ast.ArrayNode:
		c.ArrayNode(typed)
	case *ast.MapNode:
		c.MapNode(typed)
	case *ast.PairNode:
		c.PairNode(typed)

	default:
		panic(fmt.Sprintf("undefined node type (%T)", node))
	}
}
// NilNode compiles a nil literal to a single OpNil instruction.
func (c *compiler) NilNode(node *ast.NilNode) {
	c.emit(OpNil)
}
// IdentifierNode emits a fetch of the variable named node.Value. The name is
// passed as a constant-pool operand.
//
// A map environment (c.mapEnv) always uses OpFetchMap, even for nil-safe
// identifiers. Otherwise nil-safe identifiers use OpFetchNilSafe and all
// others use OpFetch.
func (c *compiler) IdentifierNode(node *ast.IdentifierNode) {
	name := c.makeConstant(node.Value)

	var op byte
	switch {
	case c.mapEnv:
		op = OpFetchMap
	case node.NilSafe:
		op = OpFetchNilSafe
	default:
		op = OpFetch
	}
	c.emit(op, name...)
}
// IntegerNode pushes an integer literal as a constant.
//
// The value is first converted to the numeric kind the checker assigned to
// the node (node.Type()). When no type is set, or the kind is not numeric,
// node.Value is pushed unchanged.
func (c *compiler) IntegerNode(node *ast.IntegerNode) {
	v := node.Value
	var value interface{} = v

	if t := node.Type(); t != nil {
		switch t.Kind() {
		case reflect.Float32:
			value = float32(v)
		case reflect.Float64:
			value = float64(v)
		case reflect.Int:
			value = int(v)
		case reflect.Int8:
			value = int8(v)
		case reflect.Int16:
			value = int16(v)
		case reflect.Int32:
			value = int32(v)
		case reflect.Int64:
			value = int64(v)
		case reflect.Uint:
			value = uint(v)
		case reflect.Uint8:
			value = uint8(v)
		case reflect.Uint16:
			value = uint16(v)
		case reflect.Uint32:
			value = uint32(v)
		case reflect.Uint64:
			value = uint64(v)
		}
	}

	c.emitPush(value)
}
// FloatNode pushes a float literal's value as a constant. Unlike
// IntegerNode, no conversion based on node.Type() is applied.
func (c *compiler) FloatNode(node *ast.FloatNode) {
	c.emitPush(node.Value)
}
// BoolNode compiles a boolean literal to OpTrue or OpFalse. No constant-pool
// slot is used.
func (c *compiler) BoolNode(node *ast.BoolNode) {
	op := OpFalse
	if node.Value {
		op = OpTrue
	}
	c.emit(op)
}
// StringNode pushes a string literal's value as a constant. Identical strings
// share a constant-pool slot (see makeConstant).
func (c *compiler) StringNode(node *ast.StringNode) {
	c.emitPush(node.Value)
}
// ConstantNode pushes an already-computed value (node.Value) as a constant.
// NOTE(review): presumably produced by an earlier constant-folding/optimizer
// pass; confirm.
func (c *compiler) ConstantNode(node *ast.ConstantNode) {
	c.emitPush(node.Value)
}
// UnaryNode compiles the operand and then applies the unary operator:
//   - "!" / "not" emits OpNot;
//   - "-" emits OpNegate;
//   - "+" emits nothing.
// Any other operator panics; Compile recovers the panic.
func (c *compiler) UnaryNode(node *ast.UnaryNode) {
	c.compile(node.Node)

	op := node.Operator
	switch {
	case op == "!" || op == "not":
		c.emit(OpNot)
	case op == "-":
		c.emit(OpNegate)
	case op == "+":
		// Unary plus is the identity: the operand value is left as-is.
	default:
		panic(fmt.Sprintf("unknown operator (%v)", node.Operator))
	}
}
// BinaryNode compiles a binary expression.
//
// Short-circuit operators ("or"/"||", "and"/"&&") work as follows:
//  1. Compile the left operand.
//  2. If the result is already decided, jump past the right operand.
//  3. Otherwise pop the left value and evaluate the right operand.
//
// Every other operator evaluates the left operand, then the right operand,
// then emits one or two opcodes.
//
// Equality ("==" and "!=") uses OpEqualInt / OpEqualString when the checker
// typed both operands as int or both as string. Otherwise it falls back to
// the generic OpEqual. "!=" negates that same comparison with OpNot.
//
// Unknown operators panic before any operand code is emitted; Compile
// recovers the panic.
func (c *compiler) BinaryNode(node *ast.BinaryNode) {
	switch node.Operator {
	case "or", "||":
		c.compile(node.Left)
		end := c.emit(OpJumpIfTrue, c.placeholder()...)
		c.emit(OpPop)
		c.compile(node.Right)
		c.patchJump(end)
		return
	case "and", "&&":
		c.compile(node.Left)
		end := c.emit(OpJumpIfFalse, c.placeholder()...)
		c.emit(OpPop)
		c.compile(node.Right)
		c.patchJump(end)
		return
	}

	// Pick the equality opcode once; "==" and "!=" share it.
	equal := OpEqual
	if l, r := kind(node.Left), kind(node.Right); l == r {
		switch l {
		case reflect.Int:
			equal = OpEqualInt
		case reflect.String:
			equal = OpEqualString
		}
	}

	// Resolve the opcode sequence for the operator. Both operands are
	// evaluated (left first) before these opcodes run.
	var ops []byte
	switch node.Operator {
	case "==":
		ops = []byte{equal}
	case "!=":
		ops = []byte{equal, OpNot}
	case "in":
		ops = []byte{OpIn}
	case "not in":
		ops = []byte{OpIn, OpNot}
	case "<":
		ops = []byte{OpLess}
	case ">":
		ops = []byte{OpMore}
	case "<=":
		ops = []byte{OpLessOrEqual}
	case ">=":
		ops = []byte{OpMoreOrEqual}
	case "+":
		ops = []byte{OpAdd}
	case "-":
		ops = []byte{OpSubtract}
	case "*":
		ops = []byte{OpMultiply}
	case "/":
		ops = []byte{OpDivide}
	case "%":
		ops = []byte{OpModulo}
	case "**":
		ops = []byte{OpExponent}
	case "contains":
		ops = []byte{OpContains}
	case "startsWith":
		ops = []byte{OpStartsWith}
	case "endsWith":
		ops = []byte{OpEndsWith}
	case "..":
		ops = []byte{OpRange}
	default:
		panic(fmt.Sprintf("unknown operator (%v)", node.Operator))
	}

	c.compile(node.Left)
	c.compile(node.Right)
	for _, op := range ops {
		c.emit(op)
	}
}
// MatchesNode compiles a `matches` expression.
//
// When the pattern was precompiled (node.Regexp is non-nil), the regexp is
// stored as a constant and OpMatchesConst is emitted. Otherwise the pattern
// expression (node.Right) is evaluated and OpMatches is emitted. The left
// operand is always compiled first.
func (c *compiler) MatchesNode(node *ast.MatchesNode) {
	c.compile(node.Left)

	if node.Regexp == nil {
		c.compile(node.Right)
		c.emit(OpMatches)
		return
	}

	c.emit(OpMatchesConst, c.makeConstant(node.Regexp)...)
}
// PropertyNode compiles `node.Node.Property` access. The property name is a
// constant-pool operand. The nil-safe form (node.NilSafe) uses
// OpPropertyNilSafe instead of OpProperty.
func (c *compiler) PropertyNode(node *ast.PropertyNode) {
	c.compile(node.Node)

	op := OpProperty
	if node.NilSafe {
		op = OpPropertyNilSafe
	}
	c.emit(op, c.makeConstant(node.Property)...)
}
// IndexNode compiles `node.Node[node.Index]`. It evaluates the container,
// then the index, then emits OpIndex.
func (c *compiler) IndexNode(node *ast.IndexNode) {
	c.compile(node.Node)
	c.compile(node.Index)
	c.emit(OpIndex)
}
// SliceNode compiles `node.Node[From:To]`. The stack is built as
// value, to, from, and then OpSlice is emitted.
//
// A missing upper bound is replaced by OpLen, which presumably leaves the
// value on the stack and pushes its length (see the len builtin, which
// follows OpLen with OpRot/OpPop). A missing lower bound becomes the
// constant 0.
func (c *compiler) SliceNode(node *ast.SliceNode) {
	c.compile(node.Node)

	if node.To == nil {
		c.emit(OpLen)
	} else {
		c.compile(node.To)
	}

	if node.From == nil {
		c.emitPush(0)
	} else {
		c.compile(node.From)
	}

	c.emit(OpSlice)
}
// MethodNode compiles a method call. The receiver is evaluated first, then
// every argument in order. The operand of the call opcode is a Call constant
// carrying the method name and argument count. Nil-safe calls use
// OpMethodNilSafe instead of OpMethod.
func (c *compiler) MethodNode(node *ast.MethodNode) {
	c.compile(node.Node)
	for i := range node.Arguments {
		c.compile(node.Arguments[i])
	}

	op := OpMethod
	if node.NilSafe {
		op = OpMethodNilSafe
	}
	call := Call{Name: node.Method, Size: len(node.Arguments)}
	c.emit(op, c.makeConstant(call)...)
}
// FunctionNode compiles a function call. Every argument is evaluated in
// order. Then OpCall is emitted, or OpCallFast when node.Fast is set, with a
// Call constant holding the function name and argument count.
func (c *compiler) FunctionNode(node *ast.FunctionNode) {
	for i := range node.Arguments {
		c.compile(node.Arguments[i])
	}

	call := Call{Name: node.Name, Size: len(node.Arguments)}
	if node.Fast {
		c.emit(OpCallFast, c.makeConstant(call)...)
	} else {
		c.emit(OpCall, c.makeConstant(call)...)
	}
}
// BuiltinNode compiles calls to the builtins len, all, none, any, one,
// filter, map and count.
//
// Arguments[0] is the collection. For every builtin except len, Arguments[1]
// is the body compiled once inside the loop produced by emitLoop. Inside that
// body, PointerNode reads the current element through the "array" and "i"
// variables.
//
// NOTE(review): arity is not checked here; presumably the checker has
// already validated it. Confirm.
//
// The looping builtins are wrapped in OpBegin/OpEnd. NOTE(review):
// presumably these open and close a variable scope for "i", "size", "array"
// and "count"; confirm against the VM.
//
// Unknown builtin names panic; Compile recovers the panic.
func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
	switch node.Name {
	case "len":
		c.compile(node.Arguments[0])
		c.emit(OpLen)
		// OpLen appears to push the length while leaving the collection below
		// it. Rotate and pop so only the length remains.
		c.emit(OpRot)
		c.emit(OpPop)
	case "all":
		c.compile(node.Arguments[0])
		c.emit(OpBegin)
		var loopBreak int
		c.emitLoop(func() {
			c.compile(node.Arguments[1])
			// On a false predicate, leave the loop. The false value stays on
			// the stack as the result, and the OpTrue below is skipped.
			loopBreak = c.emit(OpJumpIfFalse, c.placeholder()...)
			c.emit(OpPop)
		})
		// Loop completed without a false predicate: the result is true.
		c.emit(OpTrue)
		c.patchJump(loopBreak)
		c.emit(OpEnd)
	case "none":
		c.compile(node.Arguments[0])
		c.emit(OpBegin)
		var loopBreak int
		c.emitLoop(func() {
			c.compile(node.Arguments[1])
			// Negate the predicate, so a matching element breaks out with false.
			c.emit(OpNot)
			loopBreak = c.emit(OpJumpIfFalse, c.placeholder()...)
			c.emit(OpPop)
		})
		c.emit(OpTrue)
		c.patchJump(loopBreak)
		c.emit(OpEnd)
	case "any":
		c.compile(node.Arguments[0])
		c.emit(OpBegin)
		var loopBreak int
		c.emitLoop(func() {
			c.compile(node.Arguments[1])
			// On a true predicate, leave the loop with true on the stack.
			loopBreak = c.emit(OpJumpIfTrue, c.placeholder()...)
			c.emit(OpPop)
		})
		// No element matched: the result is false.
		c.emit(OpFalse)
		c.patchJump(loopBreak)
		c.emit(OpEnd)
	case "one":
		// Count matching elements in "count"; the result is count == 1.
		count := c.makeConstant("count")
		c.compile(node.Arguments[0])
		c.emit(OpBegin)
		c.emitPush(0)
		c.emit(OpStore, count...)
		c.emitLoop(func() {
			c.compile(node.Arguments[1])
			c.emitCond(func() {
				c.emit(OpInc, count...)
			})
		})
		c.emit(OpLoad, count...)
		c.emitPush(1)
		c.emit(OpEqual)
		c.emit(OpEnd)
	case "filter":
		// For each matching element, bump "count" and push array[i] onto the
		// stack. Afterwards, push count and emit OpArray. Presumably OpArray
		// collects that many stack values into an array; confirm.
		count := c.makeConstant("count")
		c.compile(node.Arguments[0])
		c.emit(OpBegin)
		c.emitPush(0)
		c.emit(OpStore, count...)
		c.emitLoop(func() {
			c.compile(node.Arguments[1])
			c.emitCond(func() {
				c.emit(OpInc, count...)
				c.emit(OpLoad, c.makeConstant("array")...)
				c.emit(OpLoad, c.makeConstant("i")...)
				c.emit(OpIndex)
			})
		})
		c.emit(OpLoad, count...)
		c.emit(OpEnd)
		c.emit(OpArray)
	case "map":
		// Each iteration leaves one body result on the stack. emitLoop returns
		// the "size" variable's operand, which is loaded as the element count
		// for OpArray.
		c.compile(node.Arguments[0])
		c.emit(OpBegin)
		size := c.emitLoop(func() {
			c.compile(node.Arguments[1])
		})
		c.emit(OpLoad, size...)
		c.emit(OpEnd)
		c.emit(OpArray)
	case "count":
		// Result is the number of elements for which the predicate is truthy.
		count := c.makeConstant("count")
		c.compile(node.Arguments[0])
		c.emit(OpBegin)
		c.emitPush(0)
		c.emit(OpStore, count...)
		c.emitLoop(func() {
			c.compile(node.Arguments[1])
			c.emitCond(func() {
				c.emit(OpInc, count...)
			})
		})
		c.emit(OpLoad, count...)
		c.emit(OpEnd)
	default:
		panic(fmt.Sprintf("unknown builtin %v", node.Name))
	}
}
// emitCond emits code that runs body only when the value on top of the stack
// is truthy. The condition value is discarded on both paths:
//
//	        JumpIfFalse noop
//	        Pop             ; true path: drop the condition
//	        <body>
//	        Jump end
//	noop:   Pop             ; false path: drop the condition
//	end:
func (c *compiler) emitCond(body func()) {
	noop := c.emit(OpJumpIfFalse, c.placeholder()...)
	c.emit(OpPop)
	body()
	jmp := c.emit(OpJump, c.placeholder()...)
	c.patchJump(noop)
	c.emit(OpPop)
	c.patchJump(jmp)
}
// emitLoop emits an index-based loop over the collection on top of the stack.
//
// Setup: OpLen pushes the collection's length on top of it. The length is
// stored in "size" and then the collection in "array" (presumably each
// OpStore pops its value). Finally "i" is set to 0.
//
// Each iteration checks i < size, runs body, increments i, and jumps back to
// the check. When the check fails, the final OpPop discards the false
// comparison result. Callers that break out of body (all/none/any in
// BuiltinNode) patch their jump past this point, which skips that OpPop.
//
// It returns the encoded constant-pool operand of the "size" name, so
// callers can OpLoad it.
func (c *compiler) emitLoop(body func()) []byte {
	i := c.makeConstant("i")
	size := c.makeConstant("size")
	array := c.makeConstant("array")
	c.emit(OpLen)
	c.emit(OpStore, size...)
	c.emit(OpStore, array...)
	c.emitPush(0)
	c.emit(OpStore, i...)
	// cond is the offset of the loop-condition check; the backward jump targets it.
	cond := len(c.bytecode)
	c.emit(OpLoad, i...)
	c.emit(OpLoad, size...)
	c.emit(OpLess)
	end := c.emit(OpJumpIfFalse, c.placeholder()...)
	c.emit(OpPop)
	body()
	c.emit(OpInc, i...)
	c.emit(OpJumpBackward, c.calcBackwardJump(cond)...)
	c.patchJump(end)
	c.emit(OpPop)
	return size
}
// ClosureNode compiles a closure as its body, inline at the point of use. No
// separate function object is emitted; builtins compile the closure inside
// the loop produced by emitLoop.
func (c *compiler) ClosureNode(node *ast.ClosureNode) {
	c.compile(node.Node)
}
// PointerNode pushes the current loop element, array[i]. It loads the
// "array" and "i" variables set up by emitLoop and indexes them. The node
// itself carries no information used here.
func (c *compiler) PointerNode(node *ast.PointerNode) {
	c.emit(OpLoad, c.makeConstant("array")...)
	c.emit(OpLoad, c.makeConstant("i")...)
	c.emit(OpIndex)
}
// ConditionalNode compiles `cond ? exp1 : exp2`. The condition value is
// popped on whichever branch is taken:
//
//	           <cond>
//	           JumpIfFalse otherwise
//	           Pop
//	           <exp1>
//	           Jump end
//	otherwise: Pop
//	           <exp2>
//	end:
func (c *compiler) ConditionalNode(node *ast.ConditionalNode) {
	c.compile(node.Cond)
	otherwise := c.emit(OpJumpIfFalse, c.placeholder()...)
	c.emit(OpPop)
	c.compile(node.Exp1)
	end := c.emit(OpJump, c.placeholder()...)
	c.patchJump(otherwise)
	c.emit(OpPop)
	c.compile(node.Exp2)
	c.patchJump(end)
}
// ArrayNode compiles an array literal. Each element is pushed in order, then
// the element count is pushed as a constant, then OpArray is emitted.
func (c *compiler) ArrayNode(node *ast.ArrayNode) {
	elems := node.Nodes
	for idx := range elems {
		c.compile(elems[idx])
	}
	c.emitPush(len(elems))
	c.emit(OpArray)
}
// MapNode compiles a map literal. Each pair pushes its key and then its
// value (see PairNode). Then the pair count is pushed as a constant and
// OpMap is emitted.
func (c *compiler) MapNode(node *ast.MapNode) {
	n := len(node.Pairs)
	for idx := 0; idx < n; idx++ {
		c.compile(node.Pairs[idx])
	}
	c.emitPush(n)
	c.emit(OpMap)
}
// PairNode compiles one key/value entry of a map literal. It pushes the key
// and then the value. MapNode emits the pair count and OpMap afterwards.
func (c *compiler) PairNode(node *ast.PairNode) {
	c.compile(node.Key)
	c.compile(node.Value)
}
// encode returns i as a fresh two-byte little-endian slice, the operand
// format used throughout the bytecode.
func encode(i uint16) []byte {
	var buf [2]byte
	binary.LittleEndian.PutUint16(buf[:], i)
	return buf[:]
}
// kind reports the reflect.Kind of the type the checker assigned to node.
// It returns reflect.Invalid when no type is set.
func kind(node ast.Node) reflect.Kind {
	if t := node.Type(); t != nil {
		return t.Kind()
	}
	return reflect.Invalid
}