diff --git a/proxy.conf.in b/proxy.conf.in index 175cbf0..98b0fde 100644 --- a/proxy.conf.in +++ b/proxy.conf.in @@ -59,7 +59,8 @@ blockkey = -encryption-key- #cacert = /path/to/etcd-ca.crt # For token type "etcd": Format of key name to retrieve the public key from, -# "%s" will be replaced with the token id. +# "%s" will be replaced with the token id. Multiple possible formats can be +# comma-separated. #keyformat = /signaling/proxy/tokens/%s/public-key [mcu] diff --git a/src/proxy/proxy_tokens_etcd.go b/src/proxy/proxy_tokens_etcd.go index cd1469b..0c1c309 100644 --- a/src/proxy/proxy_tokens_etcd.go +++ b/src/proxy/proxy_tokens_etcd.go @@ -53,8 +53,8 @@ type tokenCacheEntry struct { type tokensEtcd struct { client atomic.Value - tokenFormat atomic.Value - tokenCache *signaling.LruCache + tokenFormats atomic.Value + tokenCache *signaling.LruCache } func NewProxyTokensEtcd(config *goconf.ConfigFile) (ProxyTokens, error) { @@ -77,15 +77,20 @@ func (t *tokensEtcd) getClient() *clientv3.Client { return c.(*clientv3.Client) } -func (t *tokensEtcd) getKey(id string) string { - format := t.tokenFormat.Load().(string) - return fmt.Sprintf(format, id) +func (t *tokensEtcd) getKeys(id string) []string { + format := t.tokenFormats.Load().([]string) + var result []string + for _, f := range format { + result = append(result, fmt.Sprintf(f, id)) + } + return result } -func (t *tokensEtcd) Get(id string) (*ProxyToken, error) { +func (t *tokensEtcd) getByKey(id string, key string) (*ProxyToken, error) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - resp, err := t.getClient().Get(ctx, t.getKey(id)) + + resp, err := t.getClient().Get(ctx, key) if err != nil { return nil, err } @@ -93,31 +98,47 @@ func (t *tokensEtcd) Get(id string) (*ProxyToken, error) { if len(resp.Kvs) == 0 { return nil, nil } else if len(resp.Kvs) > 1 { - log.Printf("Received multiple keys for %s, using last", id) + log.Printf("Received multiple keys for %s, using last", key) } keyValue := resp.Kvs[len(resp.Kvs)-1].Value - cached, _ := t.tokenCache.Get(id).(*tokenCacheEntry) + cached, _ := t.tokenCache.Get(key).(*tokenCacheEntry) if cached == nil || !bytes.Equal(cached.keyValue, keyValue) { // Parsed public keys are cached to avoid the parse overhead. - key, err := jwt.ParseRSAPublicKeyFromPEM(keyValue) + publicKey, err := jwt.ParseRSAPublicKeyFromPEM(keyValue) if err != nil { - return nil, fmt.Errorf("Could not parse public key for %s: %s", id, err) + return nil, err } cached = &tokenCacheEntry{ keyValue: keyValue, token: &ProxyToken{ id: id, - key: key, + key: publicKey, }, } - t.tokenCache.Set(id, cached) + t.tokenCache.Set(key, cached) } return cached.token, nil } +func (t *tokensEtcd) Get(id string) (*ProxyToken, error) { + for _, k := range t.getKeys(id) { + token, err := t.getByKey(id, k) + if err != nil { + log.Printf("Could not get public key from %s for %s: %s", k, id, err) + continue + } else if token == nil { + continue + } + + return token, nil + } + + return nil, nil +} + func (t *tokensEtcd) load(config *goconf.ConfigFile, ignoreErrors bool) error { var endpoints []string if endpointsString, _ := config.GetString("tokens", "endpoints"); endpointsString != "" { @@ -196,8 +217,17 @@ func (t *tokensEtcd) load(config *goconf.ConfigFile, ignoreErrors bool) error { tokenFormat = "/%s" } - t.tokenFormat.Store(tokenFormat) - log.Printf("Using %s as token format", tokenFormat) + formats := strings.Split(tokenFormat, ",") + var tokenFormats []string + for _, f := range formats { + f = strings.TrimSpace(f) + if f != "" { + tokenFormats = append(tokenFormats, f) + } + } + + t.tokenFormats.Store(tokenFormats) + log.Printf("Using %v as token formats", tokenFormats) return nil }