diff --git a/store.go b/store.go index 876ccde..ba2cfc4 100644 --- a/store.go +++ b/store.go @@ -2,6 +2,8 @@ package mongostore import ( "context" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" "net/http" "time" @@ -14,16 +16,45 @@ type MongoStore struct { client *mongo.Client Codecs []securecookie.Codec Options *sessions.Options + + dbDatabaseName string + dbCollectionName string + dbTimeout time.Duration } -func NewMongoStore(client *mongo.Client) *MongoStore { +func NewMongoStore(client *mongo.Client, dbDatabase, dbCollection string, dbTimeout time.Duration) *MongoStore { s := new(MongoStore) s.client = client + s.dbDatabaseName = dbDatabase + s.dbCollectionName = dbCollection + s.dbTimeout = dbTimeout return s } func (s *MongoStore) New(r *http.Request, name string) (*sessions.Session, error) { - panic("not implemented") + session := sessions.NewSession(s, name) + options := *s.Options + session.Options = &options + session.IsNew = true + cookie, e := r.Cookie(name) + if e != nil { + return nil, e + } + e = securecookie.DecodeMulti(name, cookie.Value, &session.ID, s.Codecs...) + if e != nil { + return nil, e + } + e = s.load(session) + if e != nil { + return nil, e + } else { + session.IsNew = false + } + return session, nil +} + +func (s *MongoStore) Get(r *http.Request, name string) (*sessions.Session, error) { + return sessions.GetRegistry(r).Get(s, name) } func (s *MongoStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error { @@ -38,3 +69,21 @@ func (s *MongoStore) Close() error { ctx, _ := context.WithTimeout(context.Background(), 3*time.Second) return s.client.Disconnect(ctx) } + +func (s *MongoStore) load(session *sessions.Session) error { + oid, e := primitive.ObjectIDFromHex(session.ID) + if e != nil { + return e + } + collection := s.client.Database(s.dbDatabaseName).Collection(s.dbCollectionName) + ctx, _ := context.WithTimeout(context.Background(), s.dbTimeout) + result := collection.FindOne(ctx, bson.M{"_id": oid}) + ns := new(Session) + if e := result.Decode(ns); e != nil { + return e + } + if e := securecookie.DecodeMulti(session.Name(), string(ns.Data), &session.Values, s.Codecs...); e != nil { + return e + } + return nil +}