woodpecker-email/vendor/github.com/antonmedv/expr/vm/vm.go
2023-01-04 13:11:21 +01:00

485 lines
9 KiB
Go

package vm
import (
"fmt"
"reflect"
"regexp"
"strings"
"github.com/antonmedv/expr/file"
)
var errorType = reflect.TypeOf((*error)(nil)).Elem()
var (
MemoryBudget int = 1e6
)
func Run(program *Program, env interface{}) (interface{}, error) {
if program == nil {
return nil, fmt.Errorf("program is nil")
}
vm := VM{}
return vm.Run(program, env)
}
type VM struct {
stack []interface{}
constants []interface{}
bytecode []byte
ip int
pp int
scopes []Scope
debug bool
step chan struct{}
curr chan int
memory int
limit int
}
func Debug() *VM {
vm := &VM{
debug: true,
step: make(chan struct{}, 0),
curr: make(chan int, 0),
}
return vm
}
func (vm *VM) Run(program *Program, env interface{}) (out interface{}, err error) {
defer func() {
if r := recover(); r != nil {
f := &file.Error{
Location: program.Locations[vm.pp],
Message: fmt.Sprintf("%v", r),
}
err = f.Bind(program.Source)
}
}()
vm.limit = MemoryBudget
vm.ip = 0
vm.pp = 0
if vm.stack == nil {
vm.stack = make([]interface{}, 0, 2)
} else {
vm.stack = vm.stack[0:0]
}
if vm.scopes != nil {
vm.scopes = vm.scopes[0:0]
}
vm.bytecode = program.Bytecode
vm.constants = program.Constants
for vm.ip < len(vm.bytecode) {
if vm.debug {
<-vm.step
}
vm.pp = vm.ip
vm.ip++
op := vm.bytecode[vm.pp]
switch op {
case OpPush:
vm.push(vm.constant())
case OpPop:
vm.pop()
case OpRot:
b := vm.pop()
a := vm.pop()
vm.push(b)
vm.push(a)
case OpFetch:
vm.push(fetch(env, vm.constant(), false))
case OpFetchNilSafe:
vm.push(fetch(env, vm.constant(), true))
case OpFetchMap:
vm.push(env.(map[string]interface{})[vm.constant().(string)])
case OpTrue:
vm.push(true)
case OpFalse:
vm.push(false)
case OpNil:
vm.push(nil)
case OpNegate:
v := negate(vm.pop())
vm.push(v)
case OpNot:
v := vm.pop().(bool)
vm.push(!v)
case OpEqual:
b := vm.pop()
a := vm.pop()
vm.push(equal(a, b))
case OpEqualInt:
b := vm.pop()
a := vm.pop()
vm.push(a.(int) == b.(int))
case OpEqualString:
b := vm.pop()
a := vm.pop()
vm.push(a.(string) == b.(string))
case OpJump:
offset := vm.arg()
vm.ip += int(offset)
case OpJumpIfTrue:
offset := vm.arg()
if vm.current().(bool) {
vm.ip += int(offset)
}
case OpJumpIfFalse:
offset := vm.arg()
if !vm.current().(bool) {
vm.ip += int(offset)
}
case OpJumpBackward:
offset := vm.arg()
vm.ip -= int(offset)
case OpIn:
b := vm.pop()
a := vm.pop()
vm.push(in(a, b))
case OpLess:
b := vm.pop()
a := vm.pop()
vm.push(less(a, b))
case OpMore:
b := vm.pop()
a := vm.pop()
vm.push(more(a, b))
case OpLessOrEqual:
b := vm.pop()
a := vm.pop()
vm.push(lessOrEqual(a, b))
case OpMoreOrEqual:
b := vm.pop()
a := vm.pop()
vm.push(moreOrEqual(a, b))
case OpAdd:
b := vm.pop()
a := vm.pop()
vm.push(add(a, b))
case OpSubtract:
b := vm.pop()
a := vm.pop()
vm.push(subtract(a, b))
case OpMultiply:
b := vm.pop()
a := vm.pop()
vm.push(multiply(a, b))
case OpDivide:
b := vm.pop()
a := vm.pop()
vm.push(divide(a, b))
case OpModulo:
b := vm.pop()
a := vm.pop()
vm.push(modulo(a, b))
case OpExponent:
b := vm.pop()
a := vm.pop()
vm.push(exponent(a, b))
case OpRange:
b := vm.pop()
a := vm.pop()
min := toInt(a)
max := toInt(b)
size := max - min + 1
if vm.memory+size >= vm.limit {
panic("memory budget exceeded")
}
vm.push(makeRange(min, max))
vm.memory += size
case OpMatches:
b := vm.pop()
a := vm.pop()
match, err := regexp.MatchString(b.(string), a.(string))
if err != nil {
panic(err)
}
vm.push(match)
case OpMatchesConst:
a := vm.pop()
r := vm.constant().(*regexp.Regexp)
vm.push(r.MatchString(a.(string)))
case OpContains:
b := vm.pop()
a := vm.pop()
vm.push(strings.Contains(a.(string), b.(string)))
case OpStartsWith:
b := vm.pop()
a := vm.pop()
vm.push(strings.HasPrefix(a.(string), b.(string)))
case OpEndsWith:
b := vm.pop()
a := vm.pop()
vm.push(strings.HasSuffix(a.(string), b.(string)))
case OpIndex:
b := vm.pop()
a := vm.pop()
vm.push(fetch(a, b, false))
case OpSlice:
from := vm.pop()
to := vm.pop()
node := vm.pop()
vm.push(slice(node, from, to))
case OpProperty:
a := vm.pop()
b := vm.constant()
vm.push(fetch(a, b, false))
case OpPropertyNilSafe:
a := vm.pop()
b := vm.constant()
vm.push(fetch(a, b, true))
case OpCall:
call := vm.constant().(Call)
in := make([]reflect.Value, call.Size)
for i := call.Size - 1; i >= 0; i-- {
param := vm.pop()
if param == nil && reflect.TypeOf(param) == nil {
// In case of nil value and nil type use this hack,
// otherwise reflect.Call will panic on zero value.
in[i] = reflect.ValueOf(&param).Elem()
} else {
in[i] = reflect.ValueOf(param)
}
}
out := FetchFn(env, call.Name).Call(in)
if len(out) == 2 && out[1].Type() == errorType && !out[1].IsNil() {
return nil, out[1].Interface().(error)
}
vm.push(out[0].Interface())
case OpCallFast:
call := vm.constant().(Call)
in := make([]interface{}, call.Size)
for i := call.Size - 1; i >= 0; i-- {
in[i] = vm.pop()
}
fn := FetchFn(env, call.Name).Interface()
if typed, ok := fn.(func(...interface{}) interface{}); ok {
vm.push(typed(in...))
} else if typed, ok := fn.(func(...interface{}) (interface{}, error)); ok {
res, err := typed(in...)
if err != nil {
return nil, err
}
vm.push(res)
}
case OpMethod:
call := vm.constants[vm.arg()].(Call)
in := make([]reflect.Value, call.Size)
for i := call.Size - 1; i >= 0; i-- {
param := vm.pop()
if param == nil && reflect.TypeOf(param) == nil {
// In case of nil value and nil type use this hack,
// otherwise reflect.Call will panic on zero value.
in[i] = reflect.ValueOf(&param).Elem()
} else {
in[i] = reflect.ValueOf(param)
}
}
out := FetchFn(vm.pop(), call.Name).Call(in)
if len(out) == 2 && out[1].Type() == errorType && !out[1].IsNil() {
return nil, out[1].Interface().(error)
}
vm.push(out[0].Interface())
case OpMethodNilSafe:
call := vm.constants[vm.arg()].(Call)
in := make([]reflect.Value, call.Size)
for i := call.Size - 1; i >= 0; i-- {
param := vm.pop()
if param == nil && reflect.TypeOf(param) == nil {
// In case of nil value and nil type use this hack,
// otherwise reflect.Call will panic on zero value.
in[i] = reflect.ValueOf(&param).Elem()
} else {
in[i] = reflect.ValueOf(param)
}
}
fn := FetchFnNil(vm.pop(), call.Name)
if !fn.IsValid() {
vm.push(nil)
} else {
out := fn.Call(in)
vm.push(out[0].Interface())
}
case OpArray:
size := vm.pop().(int)
array := make([]interface{}, size)
for i := size - 1; i >= 0; i-- {
array[i] = vm.pop()
}
vm.push(array)
vm.memory += size
if vm.memory >= vm.limit {
panic("memory budget exceeded")
}
case OpMap:
size := vm.pop().(int)
m := make(map[string]interface{})
for i := size - 1; i >= 0; i-- {
value := vm.pop()
key := vm.pop()
m[key.(string)] = value
}
vm.push(m)
vm.memory += size
if vm.memory >= vm.limit {
panic("memory budget exceeded")
}
case OpLen:
vm.push(length(vm.current()))
case OpCast:
t := vm.arg()
switch t {
case 0:
vm.push(toInt64(vm.pop()))
case 1:
vm.push(toFloat64(vm.pop()))
}
case OpStore:
scope := vm.Scope()
key := vm.constant().(string)
value := vm.pop()
scope[key] = value
case OpLoad:
scope := vm.Scope()
key := vm.constant().(string)
vm.push(scope[key])
case OpInc:
scope := vm.Scope()
key := vm.constant().(string)
i := scope[key].(int)
i++
scope[key] = i
case OpBegin:
scope := make(Scope)
vm.scopes = append(vm.scopes, scope)
case OpEnd:
vm.scopes = vm.scopes[:len(vm.scopes)-1]
default:
panic(fmt.Sprintf("unknown bytecode %#x", op))
}
if vm.debug {
vm.curr <- vm.ip
}
}
if vm.debug {
close(vm.curr)
close(vm.step)
}
if len(vm.stack) > 0 {
return vm.pop(), nil
}
return nil, nil
}
func (vm *VM) push(value interface{}) {
vm.stack = append(vm.stack, value)
}
func (vm *VM) current() interface{} {
return vm.stack[len(vm.stack)-1]
}
func (vm *VM) pop() interface{} {
value := vm.stack[len(vm.stack)-1]
vm.stack = vm.stack[:len(vm.stack)-1]
return value
}
func (vm *VM) arg() uint16 {
b0, b1 := vm.bytecode[vm.ip], vm.bytecode[vm.ip+1]
vm.ip += 2
return uint16(b0) | uint16(b1)<<8
}
func (vm *VM) constant() interface{} {
return vm.constants[vm.arg()]
}
func (vm *VM) Stack() []interface{} {
return vm.stack
}
func (vm *VM) Scope() Scope {
if len(vm.scopes) > 0 {
return vm.scopes[len(vm.scopes)-1]
}
return nil
}
func (vm *VM) Step() {
if vm.ip < len(vm.bytecode) {
vm.step <- struct{}{}
}
}
func (vm *VM) Position() chan int {
return vm.curr
}