90 lines
2.2 KiB
Go
90 lines
2.2 KiB
Go
package mongostore
|
|
|
|
import (
|
|
"context"
|
|
"go.mongodb.org/mongo-driver/bson"
|
|
"go.mongodb.org/mongo-driver/bson/primitive"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/gorilla/securecookie"
|
|
"github.com/gorilla/sessions"
|
|
"go.mongodb.org/mongo-driver/mongo"
|
|
)
|
|
|
|
type MongoStore struct {
|
|
client *mongo.Client
|
|
Codecs []securecookie.Codec
|
|
Options *sessions.Options
|
|
|
|
dbDatabaseName string
|
|
dbCollectionName string
|
|
dbTimeout time.Duration
|
|
}
|
|
|
|
func NewMongoStore(client *mongo.Client, dbDatabase, dbCollection string, dbTimeout time.Duration) *MongoStore {
|
|
s := new(MongoStore)
|
|
s.client = client
|
|
s.dbDatabaseName = dbDatabase
|
|
s.dbCollectionName = dbCollection
|
|
s.dbTimeout = dbTimeout
|
|
return s
|
|
}
|
|
|
|
func (s *MongoStore) New(r *http.Request, name string) (*sessions.Session, error) {
|
|
session := sessions.NewSession(s, name)
|
|
options := *s.Options
|
|
session.Options = &options
|
|
session.IsNew = true
|
|
cookie, e := r.Cookie(name)
|
|
if e != nil {
|
|
return nil, e
|
|
}
|
|
e = securecookie.DecodeMulti(name, cookie.Value, &session.ID, s.Codecs...)
|
|
if e != nil {
|
|
return nil, e
|
|
}
|
|
e = s.load(session)
|
|
if e != nil {
|
|
return nil, e
|
|
} else {
|
|
session.IsNew = false
|
|
}
|
|
return session, nil
|
|
}
|
|
|
|
func (s *MongoStore) Get(r *http.Request, name string) (*sessions.Session, error) {
|
|
return sessions.GetRegistry(r).Get(s, name)
|
|
}
|
|
|
|
func (s *MongoStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
|
|
panic("not implemented")
|
|
}
|
|
|
|
func (s *MongoStore) Delete(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
|
|
panic("not implemented")
|
|
}
|
|
|
|
func (s *MongoStore) Close() error {
|
|
ctx, _ := context.WithTimeout(context.Background(), 3*time.Second)
|
|
return s.client.Disconnect(ctx)
|
|
}
|
|
|
|
func (s *MongoStore) load(session *sessions.Session) error {
|
|
oid, e := primitive.ObjectIDFromHex(session.ID)
|
|
if e != nil {
|
|
return e
|
|
}
|
|
collection := s.client.Database(s.dbDatabaseName).Collection(s.dbCollectionName)
|
|
ctx, _ := context.WithTimeout(context.Background(), s.dbTimeout)
|
|
result := collection.FindOne(ctx, bson.M{"_id": oid})
|
|
ns := new(Session)
|
|
if e := result.Decode(ns); e != nil {
|
|
return e
|
|
}
|
|
if e := securecookie.DecodeMulti(session.Name(), string(ns.Data), &session.Values, s.Codecs...); e != nil {
|
|
return e
|
|
}
|
|
return nil
|
|
}
|