package mongostore import ( "context" "errors" "fmt" "net/http" "time" "github.com/gorilla/securecookie" "github.com/gorilla/sessions" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" ) type MongoStore struct { client *mongo.Client Codecs []securecookie.Codec Options *sessions.Options dbDatabaseName string dbCollectionName string dbTimeout time.Duration } func NewMongoStore(client *mongo.Client, dbDatabase, dbCollection string, dbTimeout time.Duration, keyPairs ...[]byte) *MongoStore { s := new(MongoStore) s.client = client s.dbDatabaseName = dbDatabase s.dbCollectionName = dbCollection s.dbTimeout = dbTimeout s.Options = &sessions.Options{ Path: "/", MaxAge: 86400 * 30 * 365, } s.Codecs = securecookie.CodecsFromPairs(keyPairs...) return s } func (s *MongoStore) MaxAge(age int) { s.Options.MaxAge = age for _, codec := range s.Codecs { if sc, ok := codec.(*securecookie.SecureCookie); ok { sc.MaxAge(age) } } } func (s *MongoStore) New(r *http.Request, name string) (*sessions.Session, error) { session := sessions.NewSession(s, name) session.Options = &sessions.Options{ Path: s.Options.Path, MaxAge: s.Options.MaxAge, Domain: s.Options.Domain, Secure: s.Options.Secure, HttpOnly: s.Options.HttpOnly, } session.IsNew = true cookie, e := r.Cookie(name) if e != nil { return session, e } e = securecookie.DecodeMulti(name, cookie.Value, &session.ID, s.Codecs...) if e != nil { return session, e } e = s.load(session) if e != nil { return session, e } else { session.IsNew = false } return session, nil } func (s *MongoStore) Get(r *http.Request, name string) (*sessions.Session, error) { return sessions.GetRegistry(r).Get(s, name) } func (s *MongoStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error { if session.Options.MaxAge < 0 { if e := s.remove(session); e != nil { return e } http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options)) } else { if session.ID == "" { session.ID = primitive.NewObjectID().Hex() } if e := s.save(session); e != nil { return e } encoded, e := securecookie.EncodeMulti(session.Name(), session.ID, s.Codecs...) if e != nil { return e } http.SetCookie(w, sessions.NewCookie(session.Name(), encoded, session.Options)) } return nil } func (s *MongoStore) Close() error { ctx, _ := context.WithTimeout(context.Background(), s.dbTimeout) return s.client.Disconnect(ctx) } func (s *MongoStore) load(session *sessions.Session) error { oid, e := primitive.ObjectIDFromHex(session.ID) if e != nil { return e } collection := s.client.Database(s.dbDatabaseName).Collection(s.dbCollectionName) ctx, _ := context.WithTimeout(context.Background(), s.dbTimeout) result := collection.FindOne(ctx, bson.M{"_id": oid}) ns := new(Session) if e := result.Decode(ns); e != nil { return e } if e := securecookie.DecodeMulti(session.Name(), ns.Data, &session.Values, s.Codecs...); e != nil { return e } return nil } func (s *MongoStore) save(session *sessions.Session) error { oid, e := primitive.ObjectIDFromHex(session.ID) if e != nil { return e } collection := s.client.Database(s.dbDatabaseName).Collection(s.dbCollectionName) ctx, _ := context.WithTimeout(context.Background(), s.dbTimeout) var modified time.Time if val, ok := session.Values["modified"]; ok { modified, ok = val.(time.Time) if !ok { return errors.New("mongostore: invalid modified value") } } else { modified = time.Now() } encoded, e := securecookie.EncodeMulti(session.Name(), session.Values, s.Codecs...) if e != nil { return e } ns := new(Session) ns.ID = oid ns.Data = encoded ns.Modified = modified if session.IsNew { if _, e := collection.InsertOne(ctx, ns); e != nil { return e } } else { if _, e := collection.ReplaceOne(ctx, bson.M{"_id": oid}, ns); e != nil { return e } } return nil } func (s *MongoStore) remove(session *sessions.Session) error { oid, e := primitive.ObjectIDFromHex(session.ID) if e != nil { return e } collection := s.client.Database(s.dbDatabaseName).Collection(s.dbCollectionName) ctx, _ := context.WithTimeout(context.Background(), s.dbTimeout) dr, e := collection.DeleteOne(ctx, bson.M{"_id": oid}) if e != nil { return e } if dr.DeletedCount != 1 { return errors.New(fmt.Sprintf("DeletedCount (%d) != 1", dr.DeletedCount)) } return nil }