diff --git a/model.go b/model.go index e9bb037..1a34f76 100644 --- a/model.go +++ b/model.go @@ -8,6 +8,6 @@ import ( type Session struct { ID primitive.ObjectID `bson:"_id,omitempty"` - Data []byte `bson:"data"` + Data string `bson:"data"` Modified time.Time `bson:"modified"` } diff --git a/store.go b/store.go index ba2cfc4..8ad8024 100644 --- a/store.go +++ b/store.go @@ -2,6 +2,8 @@ package mongostore import ( "context" + "errors" + "fmt" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" "net/http" @@ -58,15 +60,29 @@ func (s *MongoStore) Get(r *http.Request, name string) (*sessions.Session, error } func (s *MongoStore) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error { - panic("not implemented") -} - -func (s *MongoStore) Delete(r *http.Request, w http.ResponseWriter, session *sessions.Session) error { - panic("not implemented") + if session.Options.MaxAge < 0 { + if e := s.remove(session); e != nil { + return e + } + http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options)) + } else { + if session.ID == "" { + session.ID = primitive.NewObjectID().Hex() + } + if e := s.save(session); e != nil { + return e + } + encoded, e := securecookie.EncodeMulti(session.Name(), session.ID, s.Codecs...) + if e != nil { + return e + } + http.SetCookie(w, sessions.NewCookie(session.Name(), encoded, session.Options)) + } + return nil } func (s *MongoStore) Close() error { - ctx, _ := context.WithTimeout(context.Background(), 3*time.Second) + ctx, _ := context.WithTimeout(context.Background(), s.dbTimeout) return s.client.Disconnect(ctx) } @@ -82,8 +98,56 @@ func (s *MongoStore) load(session *sessions.Session) error { if e := result.Decode(ns); e != nil { return e } - if e := securecookie.DecodeMulti(session.Name(), string(ns.Data), &session.Values, s.Codecs...); e != nil { + if e := securecookie.DecodeMulti(session.Name(), ns.Data, &session.Values, s.Codecs...); e != nil { return e } return nil } + +func (s *MongoStore) save(session *sessions.Session) error { + oid, e := primitive.ObjectIDFromHex(session.ID) + if e != nil { + return e + } + collection := s.client.Database(s.dbDatabaseName).Collection(s.dbCollectionName) + ctx, _ := context.WithTimeout(context.Background(), s.dbTimeout) + + var modified time.Time + if val, ok := session.Values["modified"]; ok { + modified, ok = val.(time.Time) + if !ok { + return errors.New("mongostore: invalid modified value") + } + } else { + modified = time.Now() + } + encoded, e := securecookie.EncodeMulti(session.Name(), session.Values, s.Codecs...) + if e != nil { + return e + } + ns := new(Session) + ns.ID = oid + ns.Data = encoded + ns.Modified = modified + if _, e := collection.InsertOne(ctx, ns); e != nil { + return e + } + return nil +} + +func (s *MongoStore) remove(session *sessions.Session) error { + oid, e := primitive.ObjectIDFromHex(session.ID) + if e != nil { + return e + } + collection := s.client.Database(s.dbDatabaseName).Collection(s.dbCollectionName) + ctx, _ := context.WithTimeout(context.Background(), s.dbTimeout) + dr, e := collection.DeleteOne(ctx, bson.M{"_id": oid}) + if e != nil { + return e + } + if dr.DeletedCount != 1 { + return errors.New(fmt.Sprintf("DeletedCount (%d) != 1", dr.DeletedCount)) + } + return nil +}