diff --git a/go.mod b/go.mod index f9090d5..a9aabbb 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/dalu/mongostore go 1.14 require ( + github.com/gorilla/securecookie v1.1.1 github.com/gorilla/sessions v1.2.0 go.mongodb.org/mongo-driver v1.3.1 ) diff --git a/store.go b/store.go index 8ad8024..ea964f2 100644 --- a/store.go +++ b/store.go @@ -4,13 +4,13 @@ import ( "context" "errors" "fmt" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/primitive" "net/http" "time" "github.com/gorilla/securecookie" "github.com/gorilla/sessions" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" ) @@ -24,31 +24,51 @@ type MongoStore struct { dbTimeout time.Duration } -func NewMongoStore(client *mongo.Client, dbDatabase, dbCollection string, dbTimeout time.Duration) *MongoStore { +func NewMongoStore(client *mongo.Client, dbDatabase, dbCollection string, dbTimeout time.Duration, keyPairs ...[]byte) *MongoStore { s := new(MongoStore) s.client = client s.dbDatabaseName = dbDatabase s.dbCollectionName = dbCollection s.dbTimeout = dbTimeout + s.Options = &sessions.Options{ + Path: "/", + MaxAge: 86400 * 30 * 365, + } + s.Codecs = securecookie.CodecsFromPairs(keyPairs...) return s } +func (s *MongoStore) MaxAge(age int) { + s.Options.MaxAge = age + + for _, codec := range s.Codecs { + if sc, ok := codec.(*securecookie.SecureCookie); ok { + sc.MaxAge(age) + } + } +} + func (s *MongoStore) New(r *http.Request, name string) (*sessions.Session, error) { session := sessions.NewSession(s, name) - options := *s.Options - session.Options = &options + session.Options = &sessions.Options{ + Path: s.Options.Path, + MaxAge: s.Options.MaxAge, + Domain: s.Options.Domain, + Secure: s.Options.Secure, + HttpOnly: s.Options.HttpOnly, + } session.IsNew = true cookie, e := r.Cookie(name) if e != nil { - return nil, e + return session, e } e = securecookie.DecodeMulti(name, cookie.Value, &session.ID, s.Codecs...) if e != nil { - return nil, e + return session, e } e = s.load(session) if e != nil { - return nil, e + return session, e } else { session.IsNew = false } @@ -129,8 +149,14 @@ func (s *MongoStore) save(session *sessions.Session) error { ns.ID = oid ns.Data = encoded ns.Modified = modified - if _, e := collection.InsertOne(ctx, ns); e != nil { - return e + if session.IsNew { + if _, e := collection.InsertOne(ctx, ns); e != nil { + return e + } + } else { + if _, e := collection.UpdateOne(ctx, bson.M{"_id": oid}, ns); e != nil { + return e + } } return nil } diff --git a/store_test.go b/store_test.go new file mode 100644 index 0000000..adabf27 --- /dev/null +++ b/store_test.go @@ -0,0 +1,176 @@ +// Copyright (c) 2013 Gregor Robinson. +// Copyright (c) 2013 Brian Jones. +// All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package mongostore + +import ( + "context" + "encoding/gob" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/mongo/readpref" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/gorilla/sessions" +) + +type FlashMessage struct { + Type int + Message string +} + +func TestMongoStore(t *testing.T) { + var req *http.Request + var rsp *httptest.ResponseRecorder + var hdr http.Header + var err error + var ok bool + var cookies []string + var session *sessions.Session + var flashes []interface{} + + // Copyright 2012 The Gorilla Authors. All rights reserved. + // Use of this source code is governed by a BSD-style + // license that can be found in the LICENSE file. + + // Round 1 ---------------------------------------------------------------- + ctx, _ := context.WithTimeout(context.Background(), 3*time.Second) + client, err := mongo.Connect(ctx, options.Client().ApplyURI("mongodb://localhost:27017")) + err = client.Ping(ctx, readpref.Primary()) + if err != nil { + panic(err) + } + + store := NewMongoStore(client, "test", "sessions", 3*time.Second) + store.Options.Path = "/" + store.Options.MaxAge = 86400 * 30 * 365 + defer store.Close() + + req, _ = http.NewRequest("GET", "http://localhost:8080/", nil) + rsp = httptest.NewRecorder() + // Get a session. + if session, err = store.Get(req, "session-key"); err != nil { + t.Fatalf("Error getting session: %v", err) + } + // Get a flash. + flashes = session.Flashes() + if len(flashes) != 0 { + t.Errorf("Expected empty flashes; Got %v", flashes) + } + // Add some flashes. + session.AddFlash("foo") + session.AddFlash("bar") + // Custom key. + session.AddFlash("baz", "custom_key") + // Save. + if err = sessions.Save(req, rsp); err != nil { + t.Fatalf("Error saving session: %v", err) + } + hdr = rsp.Header() + cookies, ok = hdr["Set-Cookie"] + if !ok || len(cookies) != 1 { + t.Fatalf("No cookies. Header: %v", hdr) + } + + // Round 2 ---------------------------------------------------------------- + + req, _ = http.NewRequest("GET", "http://localhost:8080/", nil) + req.Header.Add("Cookie", cookies[0]) + rsp = httptest.NewRecorder() + // Get a session. + if session, err = store.Get(req, "session-key"); err != nil { + t.Fatalf("Error getting session: %v", err) + } + // Check all saved values. + flashes = session.Flashes() + if len(flashes) != 2 { + t.Fatalf("Expected flashes; Got %v", flashes) + } + if flashes[0] != "foo" || flashes[1] != "bar" { + t.Errorf("Expected foo,bar; Got %v", flashes) + } + flashes = session.Flashes() + if len(flashes) != 0 { + t.Errorf("Expected dumped flashes; Got %v", flashes) + } + // Custom key. + flashes = session.Flashes("custom_key") + if len(flashes) != 1 { + t.Errorf("Expected flashes; Got %v", flashes) + } else if flashes[0] != "baz" { + t.Errorf("Expected baz; Got %v", flashes) + } + flashes = session.Flashes("custom_key") + if len(flashes) != 0 { + t.Errorf("Expected dumped flashes; Got %v", flashes) + } + + session.Options.MaxAge = -1 + // Save. + if err = sessions.Save(req, rsp); err != nil { + t.Fatalf("Error saving session: %v", err) + } + + // Round 3 ---------------------------------------------------------------- + // Custom type + + req, _ = http.NewRequest("GET", "http://localhost:8080/", nil) + rsp = httptest.NewRecorder() + // Get a session. + if session, err = store.Get(req, "session-key"); err != nil { + t.Fatalf("Error getting session: %v", err) + } + // Get a flash. + flashes = session.Flashes() + if len(flashes) != 0 { + t.Errorf("Expected empty flashes; Got %v", flashes) + } + // Add some flashes. + session.AddFlash(&FlashMessage{42, "foo"}) + // Save. + if err = sessions.Save(req, rsp); err != nil { + t.Fatalf("Error saving session: %v", err) + } + hdr = rsp.Header() + cookies, ok = hdr["Set-Cookie"] + if !ok || len(cookies) != 1 { + t.Fatalf("No cookies. Header: %v", hdr) + } + + // Round 4 ---------------------------------------------------------------- + // Custom type + + req, _ = http.NewRequest("GET", "http://localhost:8080/", nil) + req.Header.Add("Cookie", cookies[0]) + rsp = httptest.NewRecorder() + // Get a session. + if session, err = store.Get(req, "session-key"); err != nil { + t.Fatalf("Error getting session: %v", err) + } + // Check all saved values. + flashes = session.Flashes() + if len(flashes) != 1 { + t.Fatalf("Expected flashes; Got %v", flashes) + } + custom := flashes[0].(FlashMessage) + if custom.Type != 42 || custom.Message != "foo" { + t.Errorf("Expected %#v, got %#v", FlashMessage{42, "foo"}, custom) + } + + // Delete session. + session.Options.MaxAge = -1 + // Save. + if err = sessions.Save(req, rsp); err != nil { + t.Fatalf("Error saving session: %v", err) + } +} + +func init() { + gob.Register(FlashMessage{}) +}