2020-03-22 01:01:55 +01:00
|
|
|
package mongostore
|
|
|
|
|
|
|
|
import (
|
|
|
|
"context"
|
2020-04-02 19:16:51 +02:00
|
|
|
"errors"
|
|
|
|
"fmt"
|
2020-03-22 01:01:55 +01:00
|
|
|
"net/http"
|
|
|
|
"time"
|
|
|
|
|
|
|
|
"github.com/gorilla/securecookie"
|
|
|
|
"github.com/gorilla/sessions"
|
2020-04-02 20:02:33 +02:00
|
|
|
"go.mongodb.org/mongo-driver/bson"
|
|
|
|
"go.mongodb.org/mongo-driver/bson/primitive"
|
2020-03-22 01:01:55 +01:00
|
|
|
"go.mongodb.org/mongo-driver/mongo"
|
|
|
|
)
|
|
|
|
|
|
|
|
type MongoStore struct {
|
|
|
|
client *mongo.Client
|
|
|
|
Codecs []securecookie.Codec
|
|
|
|
Options *sessions.Options
|
2020-04-02 18:42:27 +02:00
|
|
|
|
|
|
|
dbDatabaseName string
|
|
|
|
dbCollectionName string
|
|
|
|
dbTimeout time.Duration
|
2020-03-22 01:01:55 +01:00
|
|
|
}
|
|
|
|
|
2020-04-02 20:02:33 +02:00
|
|
|
func NewMongoStore(client *mongo.Client, dbDatabase, dbCollection string, dbTimeout time.Duration, keyPairs ...[]byte) *MongoStore {
|
2020-03-22 01:01:55 +01:00
|
|
|
s := new(MongoStore)
|
|
|
|
s.client = client
|
2020-04-02 18:42:27 +02:00
|
|
|
s.dbDatabaseName = dbDatabase
|
|
|
|
s.dbCollectionName = dbCollection
|
|
|
|
s.dbTimeout = dbTimeout
|
2020-04-02 20:02:33 +02:00
|
|
|
s.Options = &sessions.Options{
|
|
|
|
Path: "/",
|
|
|
|
MaxAge: 86400 * 30 * 365,
|
|
|
|
}
|
|
|
|
s.Codecs = securecookie.CodecsFromPairs(keyPairs...)
|
2020-03-22 01:01:55 +01:00
|
|
|
return s
|
|
|
|
}
|
|
|
|
|
2020-04-02 20:02:33 +02:00
|
|
|
func (s *MongoStore) MaxAge(age int) {
|
|
|
|
s.Options.MaxAge = age
|
|
|
|
|
|
|
|
for _, codec := range s.Codecs {
|
|
|
|
if sc, ok := codec.(*securecookie.SecureCookie); ok {
|
|
|
|
sc.MaxAge(age)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2020-03-22 01:01:55 +01:00
|
|
|
func (s *MongoStore) New(r *http.Request, name string) (*sessions.Session, error) {
|
2020-04-02 18:42:27 +02:00
|
|
|
session := sessions.NewSession(s, name)
|
2020-04-02 20:02:33 +02:00
|
|
|
session.Options = &sessions.Options{
|
|
|
|
Path: s.Options.Path,
|
|
|
|
MaxAge: s.Options.MaxAge,
|
|
|
|
Domain: s.Options.Domain,
|
|
|
|
Secure: s.Options.Secure,
|
|
|
|
HttpOnly: s.Options.HttpOnly,
|
|
|
|
}
|
2020-04-02 18:42:27 +02:00
|
|
|
session.IsNew = true
|
|
|
|
cookie, e := r.Cookie(name)
|
|
|
|
if e != nil {
|
2020-04-02 20:02:33 +02:00
|
|
|
return session, e
|
2020-04-02 18:42:27 +02:00
|
|
|
}
|
|
|
|
e = securecookie.DecodeMulti(name, cookie.Value, &session.ID, s.Codecs...)
|
|
|
|
if e != nil {
|
2020-04-02 20:02:33 +02:00
|
|
|
return session, e
|
2020-04-02 18:42:27 +02:00
|
|
|
}
|
|
|
|
e = s.load(session)
|
|
|
|
if e != nil {
|
2020-04-02 20:02:33 +02:00
|
|
|
return session, e
|
2020-04-02 18:42:27 +02:00
|
|
|
} else {
|
|
|
|
session.IsNew = false
|
|
|
|
}
|
|
|
|
return session, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *MongoStore) Get(r *http.Request, name string) (*sessions.Session, error) {
|
|
|
|
return sessions.GetRegistry(r).Get(s, name)
|
2020-03-22 01:01:55 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
func (s *MongoStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
|
2020-04-02 19:16:51 +02:00
|
|
|
if session.Options.MaxAge < 0 {
|
|
|
|
if e := s.remove(session); e != nil {
|
|
|
|
return e
|
|
|
|
}
|
|
|
|
http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options))
|
|
|
|
} else {
|
|
|
|
if session.ID == "" {
|
|
|
|
session.ID = primitive.NewObjectID().Hex()
|
|
|
|
}
|
|
|
|
if e := s.save(session); e != nil {
|
|
|
|
return e
|
|
|
|
}
|
|
|
|
encoded, e := securecookie.EncodeMulti(session.Name(), session.ID, s.Codecs...)
|
|
|
|
if e != nil {
|
|
|
|
return e
|
|
|
|
}
|
|
|
|
http.SetCookie(w, sessions.NewCookie(session.Name(), encoded, session.Options))
|
|
|
|
}
|
|
|
|
return nil
|
2020-03-22 01:01:55 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
func (s *MongoStore) Close() error {
|
2020-04-02 19:16:51 +02:00
|
|
|
ctx, _ := context.WithTimeout(context.Background(), s.dbTimeout)
|
2020-03-22 01:01:55 +01:00
|
|
|
return s.client.Disconnect(ctx)
|
|
|
|
}
|
2020-04-02 18:42:27 +02:00
|
|
|
|
|
|
|
func (s *MongoStore) load(session *sessions.Session) error {
|
|
|
|
oid, e := primitive.ObjectIDFromHex(session.ID)
|
|
|
|
if e != nil {
|
|
|
|
return e
|
|
|
|
}
|
|
|
|
collection := s.client.Database(s.dbDatabaseName).Collection(s.dbCollectionName)
|
|
|
|
ctx, _ := context.WithTimeout(context.Background(), s.dbTimeout)
|
|
|
|
result := collection.FindOne(ctx, bson.M{"_id": oid})
|
|
|
|
ns := new(Session)
|
|
|
|
if e := result.Decode(ns); e != nil {
|
|
|
|
return e
|
|
|
|
}
|
2020-04-02 19:16:51 +02:00
|
|
|
if e := securecookie.DecodeMulti(session.Name(), ns.Data, &session.Values, s.Codecs...); e != nil {
|
2020-04-02 18:42:27 +02:00
|
|
|
return e
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
2020-04-02 19:16:51 +02:00
|
|
|
|
|
|
|
func (s *MongoStore) save(session *sessions.Session) error {
|
|
|
|
oid, e := primitive.ObjectIDFromHex(session.ID)
|
|
|
|
if e != nil {
|
|
|
|
return e
|
|
|
|
}
|
|
|
|
collection := s.client.Database(s.dbDatabaseName).Collection(s.dbCollectionName)
|
|
|
|
ctx, _ := context.WithTimeout(context.Background(), s.dbTimeout)
|
|
|
|
|
|
|
|
var modified time.Time
|
|
|
|
if val, ok := session.Values["modified"]; ok {
|
|
|
|
modified, ok = val.(time.Time)
|
|
|
|
if !ok {
|
|
|
|
return errors.New("mongostore: invalid modified value")
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
modified = time.Now()
|
|
|
|
}
|
|
|
|
encoded, e := securecookie.EncodeMulti(session.Name(), session.Values, s.Codecs...)
|
|
|
|
if e != nil {
|
|
|
|
return e
|
|
|
|
}
|
|
|
|
ns := new(Session)
|
|
|
|
ns.ID = oid
|
|
|
|
ns.Data = encoded
|
|
|
|
ns.Modified = modified
|
2020-04-02 20:02:33 +02:00
|
|
|
if session.IsNew {
|
|
|
|
if _, e := collection.InsertOne(ctx, ns); e != nil {
|
|
|
|
return e
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
if _, e := collection.UpdateOne(ctx, bson.M{"_id": oid}, ns); e != nil {
|
|
|
|
return e
|
|
|
|
}
|
2020-04-02 19:16:51 +02:00
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (s *MongoStore) remove(session *sessions.Session) error {
|
|
|
|
oid, e := primitive.ObjectIDFromHex(session.ID)
|
|
|
|
if e != nil {
|
|
|
|
return e
|
|
|
|
}
|
|
|
|
collection := s.client.Database(s.dbDatabaseName).Collection(s.dbCollectionName)
|
|
|
|
ctx, _ := context.WithTimeout(context.Background(), s.dbTimeout)
|
|
|
|
dr, e := collection.DeleteOne(ctx, bson.M{"_id": oid})
|
|
|
|
if e != nil {
|
|
|
|
return e
|
|
|
|
}
|
|
|
|
if dr.DeletedCount != 1 {
|
|
|
|
return errors.New(fmt.Sprintf("DeletedCount (%d) != 1", dr.DeletedCount))
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|