initial
This commit is contained in:
195
redisstore/redisstore.go
Normal file
195
redisstore/redisstore.go
Normal file
@ -0,0 +1,195 @@
|
||||
package redisstore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base32"
|
||||
"encoding/gob"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// RedisStore stores gorilla sessions in Redis
|
||||
type RedisStore struct {
|
||||
// client to connect to redis
|
||||
client redis.UniversalClient
|
||||
// default options to use when a new session is created
|
||||
Options sessions.Options
|
||||
// key prefix with which the session will be stored
|
||||
keyPrefix string
|
||||
// key generator
|
||||
keyGen KeyGenFunc
|
||||
// session serializer
|
||||
serializer SessionSerializer
|
||||
}
|
||||
|
||||
// KeyGenFunc defines a function used by store to generate a key
|
||||
type KeyGenFunc func() (string, error)
|
||||
|
||||
// NewRedisStore returns a new RedisStore with default configuration
|
||||
func NewRedisStore(ctx context.Context, client redis.UniversalClient) (*RedisStore, error) {
|
||||
|
||||
rs := &RedisStore{
|
||||
Options: sessions.Options{
|
||||
Path: "/",
|
||||
MaxAge: 86400 * 30,
|
||||
},
|
||||
client: client,
|
||||
keyPrefix: "session:",
|
||||
keyGen: generateRandomKey,
|
||||
serializer: GobSerializer{},
|
||||
}
|
||||
|
||||
return rs, rs.client.Ping(ctx).Err()
|
||||
}
|
||||
|
||||
// Get returns a session for the given name after adding it to the registry.
|
||||
func (s *RedisStore) Get(r *http.Request, name string) (*sessions.Session, error) {
|
||||
return sessions.GetRegistry(r).Get(s, name)
|
||||
}
|
||||
|
||||
// New returns a session for the given name without adding it to the registry.
|
||||
func (s *RedisStore) New(r *http.Request, name string) (*sessions.Session, error) {
|
||||
|
||||
session := sessions.NewSession(s, name)
|
||||
opts := s.Options
|
||||
session.Options = &opts
|
||||
session.IsNew = true
|
||||
|
||||
c, err := r.Cookie(name)
|
||||
if err != nil {
|
||||
return session, nil
|
||||
}
|
||||
session.ID = c.Value
|
||||
|
||||
err = s.load(r.Context(), session)
|
||||
if err == nil {
|
||||
session.IsNew = false
|
||||
} else if err == redis.Nil {
|
||||
err = nil // no data stored
|
||||
}
|
||||
return session, err
|
||||
}
|
||||
|
||||
// Save adds a single session to the response.
|
||||
//
|
||||
// If the Options.MaxAge of the session is <= 0 then the session file will be
|
||||
// deleted from the store. With this process it enforces the properly
|
||||
// session cookie handling so no need to trust in the cookie management in the
|
||||
// web browser.
|
||||
func (s *RedisStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
|
||||
// Delete if max-age is <= 0
|
||||
if session.Options.MaxAge <= 0 {
|
||||
if err := s.delete(r.Context(), session); err != nil {
|
||||
return err
|
||||
}
|
||||
http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options))
|
||||
return nil
|
||||
}
|
||||
|
||||
if session.ID == "" {
|
||||
id, err := s.keyGen()
|
||||
if err != nil {
|
||||
return errors.New("redisstore: failed to generate session id")
|
||||
}
|
||||
session.ID = id
|
||||
}
|
||||
if err := s.save(r.Context(), session); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
http.SetCookie(w, sessions.NewCookie(session.Name(), session.ID, session.Options))
|
||||
return nil
|
||||
}
|
||||
|
||||
// KeyPrefix sets the key prefix to store session in Redis
|
||||
func (s *RedisStore) KeyPrefix(keyPrefix string) {
|
||||
s.keyPrefix = keyPrefix
|
||||
}
|
||||
|
||||
// KeyGen sets the key generator function
|
||||
func (s *RedisStore) KeyGen(f KeyGenFunc) {
|
||||
s.keyGen = f
|
||||
}
|
||||
|
||||
// Serializer sets the session serializer to store session
|
||||
func (s *RedisStore) Serializer(ss SessionSerializer) {
|
||||
s.serializer = ss
|
||||
}
|
||||
|
||||
// Close closes the Redis store
|
||||
func (s *RedisStore) Close() error {
|
||||
return s.client.Close()
|
||||
}
|
||||
|
||||
// save writes session in Redis
|
||||
func (s *RedisStore) save(ctx context.Context, session *sessions.Session) error {
|
||||
|
||||
b, err := s.serializer.Serialize(session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.client.Set(ctx, s.keyPrefix+session.ID, b, time.Duration(session.Options.MaxAge)*time.Second).Err()
|
||||
}
|
||||
|
||||
// load reads session from Redis
|
||||
func (s *RedisStore) load(ctx context.Context, session *sessions.Session) error {
|
||||
|
||||
cmd := s.client.Get(ctx, s.keyPrefix+session.ID)
|
||||
if cmd.Err() != nil {
|
||||
return cmd.Err()
|
||||
}
|
||||
|
||||
b, err := cmd.Bytes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.serializer.Deserialize(b, session)
|
||||
}
|
||||
|
||||
// delete deletes session in Redis
|
||||
func (s *RedisStore) delete(ctx context.Context, session *sessions.Session) error {
|
||||
return s.client.Del(ctx, s.keyPrefix+session.ID).Err()
|
||||
}
|
||||
|
||||
// SessionSerializer provides an interface for serialize/deserialize a session
|
||||
type SessionSerializer interface {
|
||||
Serialize(s *sessions.Session) ([]byte, error)
|
||||
Deserialize(b []byte, s *sessions.Session) error
|
||||
}
|
||||
|
||||
// Gob serializer
|
||||
type GobSerializer struct{}
|
||||
|
||||
func (gs GobSerializer) Serialize(s *sessions.Session) ([]byte, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
enc := gob.NewEncoder(buf)
|
||||
err := enc.Encode(s.Values)
|
||||
if err == nil {
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (gs GobSerializer) Deserialize(d []byte, s *sessions.Session) error {
|
||||
dec := gob.NewDecoder(bytes.NewBuffer(d))
|
||||
return dec.Decode(&s.Values)
|
||||
}
|
||||
|
||||
// generateRandomKey returns a new random key
|
||||
func generateRandomKey() (string, error) {
|
||||
k := make([]byte, 64)
|
||||
if _, err := io.ReadFull(rand.Reader, k); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return strings.TrimRight(base32.StdEncoding.EncodeToString(k), "="), nil
|
||||
}
|
Reference in New Issue
Block a user