diff --git a/runtime/store.go b/runtime/store.go index 718b03a67..5d1244390 100644 --- a/runtime/store.go +++ b/runtime/store.go @@ -5,6 +5,8 @@ package runtime import ( "bytes" "encoding/json" + "fmt" + "log" "reflect" ) @@ -39,7 +41,7 @@ func NewStoreProvider(runtime *Runtime) *StoreProvider { // Store is where we keep named data type Store struct { name string - data interface{} + data reflect.Value dataType reflect.Type structType bool eventPrefix string @@ -54,26 +56,19 @@ type Store struct { // New creates a new store func (p *StoreProvider) New(name string, defaultValue interface{}) *Store { - result := Store{ - name: name, - runtime: p.runtime, - data: defaultValue, - } - - // initialise the store - result.init() - - return &result -} - -// NewWithOptions creates a new store with the given options -func (p *StoreProvider) NewWithOptions(options Options) *Store { + dataType := reflect.TypeOf(defaultValue) result := Store{ - name: options.Name, - notifySynchronously: options.NotifySynchronously, + name: name, + runtime: p.runtime, + data: reflect.ValueOf(defaultValue), + dataType: dataType, + structType: dataType.Kind() == reflect.Ptr, } + // Setup the sync listener + result.setupListener() + return &result } @@ -83,19 +78,6 @@ func (s *Store) OnError(callback func(error)) { s.errorHandler = callback } -// init the store -func (s *Store) init() { - - // Get the type of the data - s.dataType = reflect.TypeOf(s.data) - - // Determine if this is a struct type - s.structType = s.dataType.Kind() == reflect.Ptr - - // Setup the sync listener - s.setupListener() -} - // processUpdatedScalar will process the given scalar json func (s *Store) processUpdatedScalar(data json.RawMessage) error { @@ -108,9 +90,9 @@ func (s *Store) processUpdatedScalar(data json.RawMessage) error { // Convert to correct type if decodedVal == nil { - s.data = reflect.Zero(s.dataType).Interface() + s.data = reflect.Zero(s.dataType) } else { - s.data = reflect.ValueOf(decodedVal).Convert(s.dataType).Interface() + s.data = reflect.ValueOf(decodedVal).Convert(s.dataType) } return nil @@ -124,7 +106,7 @@ func (s *Store) processUpdatedStruct(data json.RawMessage) error { if err != nil { return err } - s.data = newData + s.data = reflect.ValueOf(newData) return nil } @@ -159,6 +141,7 @@ func (s *Store) setupListener() { if err != nil { if s.errorHandler != nil { s.errorHandler(err) + return } } @@ -184,16 +167,22 @@ func (s *Store) notify() { // Set will update the data held by the store // and notify listeners of the change -func (s *Store) Set(data interface{}) { +func (s *Store) Set(data interface{}) error { + + inType := reflect.TypeOf(data) + + if inType != s.dataType { + return fmt.Errorf("invalid data given in Store.Set(). Expected %s, got %s", s.dataType.String(), inType.String()) + } // Save data - s.data = data + s.data = reflect.ValueOf(data) // Stringify data - newdata, err := json.Marshal(s.data) + newdata, err := json.Marshal(data) if err != nil { if s.errorHandler != nil { - s.errorHandler(err) + return err } } @@ -202,6 +191,8 @@ func (s *Store) Set(data interface{}) { // Notify subscribers s.notify() + + return nil } // Subscribe will subscribe to updates to the store by @@ -211,10 +202,55 @@ func (s *Store) Subscribe(callback func(interface{})) { s.callbacks = append(s.callbacks, callback) } +func (s *Store) updaterCheck(updater interface{}) error { + + // Get type + updaterType := reflect.TypeOf(updater) + + // Check updater is a function + if updaterType.Kind() != reflect.Func { + return fmt.Errorf("invalid value given to store.Update(). Expected 'func(%s) %s'", s.dataType.String(), s.dataType.String()) + } + + // Check input param + if updaterType.NumIn() != 1 { + return fmt.Errorf("invalid number of parameters given in updater function. Expected 1") + } + + // Check input data type + if updaterType.In(0) != s.dataType { + return fmt.Errorf("invalid type for input parameter given in updater function. Expected %s, got %s", s.dataType.String(), updaterType.In(0)) + } + + // Check output param + if updaterType.NumOut() != 1 { + return fmt.Errorf("invalid number of return parameters given in updater function. Expected 1") + } + + // Check output data type + if updaterType.Out(0) != s.dataType { + return fmt.Errorf("invalid type for return parameter given in updater function. Expected %s, got %s", s.dataType.String(), updaterType.Out(0)) + } + + return nil +} + // Update takes a function that is passed the current state. // The result of that function is then set as the new state // of the store. This will notify listeners of the change -func (s *Store) Update(updater func(interface{}) interface{}) { - newData := updater(s.data) - s.Set(newData) +func (s *Store) Update(updater interface{}) { + + err := s.updaterCheck(updater) + if err != nil { + log.Fatal(err) + } + + // Build args + args := []reflect.Value{s.data} + + // Make call + results := reflect.ValueOf(updater).Call(args) + + // We will only have 1 result. Set the store to it + s.Set(results[0].Interface()) }