diff --git a/store.go b/store.go index ea964f2..c383475 100644 --- a/store.go +++ b/store.go @@ -154,7 +154,7 @@ func (s *MongoStore) save(session *sessions.Session) error { return e } } else { - if _, e := collection.UpdateOne(ctx, bson.M{"_id": oid}, ns); e != nil { + if _, e := collection.ReplaceOne(ctx, bson.M{"_id": oid}, ns); e != nil { return e } } diff --git a/store_test.go b/store_test.go index 2dfc005..8deb776 100644 --- a/store_test.go +++ b/store_test.go @@ -1,31 +1,40 @@ -// Copyright (c) 2013 Gregor Robinson. -// Copyright (c) 2013 Brian Jones. -// All rights reserved. -// Use of this source code is governed by a MIT-style +// Copyright 2012 The Gorilla Authors. All rights reserved. +// Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package mongostore +package mongostore_test import ( + "bytes" "context" "encoding/gob" - "net/http" - "net/http/httptest" - "testing" - "time" - + "github.com/dalu/mongostore" "github.com/gorilla/sessions" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readpref" + "net/http" + "net/http/httptest" + "testing" + "time" ) +// NewRecorder returns an initialized ResponseRecorder. +func NewRecorder() *httptest.ResponseRecorder { + return &httptest.ResponseRecorder{ + HeaderMap: make(http.Header), + Body: new(bytes.Buffer), + } +} + +// ---------------------------------------------------------------------------- + type FlashMessage struct { Type int Message string } -func TestMongoStore(t *testing.T) { +func TestFlashes(t *testing.T) { var req *http.Request var rsp *httptest.ResponseRecorder var hdr http.Header @@ -35,11 +44,6 @@ func TestMongoStore(t *testing.T) { var session *sessions.Session var flashes []interface{} - // Copyright 2012 The Gorilla Authors. All rights reserved. - // Use of this source code is governed by a BSD-style - // license that can be found in the LICENSE file. - - // Round 1 ---------------------------------------------------------------- ctx, _ := context.WithTimeout(context.Background(), 3*time.Second) client, err := mongo.Connect(ctx, options.Client().ApplyURI("mongodb://localhost:27017")) err = client.Ping(ctx, readpref.Primary()) @@ -47,13 +51,15 @@ func TestMongoStore(t *testing.T) { panic(err) } - store := NewMongoStore(client, "test", "sessions", 3*time.Second) + store := mongostore.NewMongoStore(client, "test", "sessions", 3*time.Second) store.Options.Path = "/" store.Options.MaxAge = 86400 * 30 * 365 defer store.Close() + // Round 1 ---------------------------------------------------------------- + req, _ = http.NewRequest("GET", "http://localhost:8080/", nil) - rsp = httptest.NewRecorder() + rsp = NewRecorder() // Get a session. if session, err = store.Get(req, "session-key"); err != nil { t.Fatalf("Error getting session: %v", err) @@ -75,14 +81,18 @@ func TestMongoStore(t *testing.T) { hdr = rsp.Header() cookies, ok = hdr["Set-Cookie"] if !ok || len(cookies) != 1 { - t.Fatalf("No cookies. Header: %v", hdr) + t.Fatal("No cookies. Header:", hdr) + } + + if _, err = store.Get(req, "session:key"); err.Error() != "sessions: invalid character in cookie name: session:key" { + t.Fatalf("Expected error due to invalid cookie name") } // Round 2 ---------------------------------------------------------------- req, _ = http.NewRequest("GET", "http://localhost:8080/", nil) req.Header.Add("Cookie", cookies[0]) - rsp = httptest.NewRecorder() + rsp = NewRecorder() // Get a session. if session, err = store.Get(req, "session-key"); err != nil { t.Fatalf("Error getting session: %v", err) @@ -111,17 +121,11 @@ func TestMongoStore(t *testing.T) { t.Errorf("Expected dumped flashes; Got %v", flashes) } - session.Options.MaxAge = -1 - // Save. - if err = sessions.Save(req, rsp); err != nil { - t.Fatalf("Error saving session: %v", err) - } - // Round 3 ---------------------------------------------------------------- // Custom type req, _ = http.NewRequest("GET", "http://localhost:8080/", nil) - rsp = httptest.NewRecorder() + rsp = NewRecorder() // Get a session. if session, err = store.Get(req, "session-key"); err != nil { t.Fatalf("Error getting session: %v", err) @@ -140,7 +144,7 @@ func TestMongoStore(t *testing.T) { hdr = rsp.Header() cookies, ok = hdr["Set-Cookie"] if !ok || len(cookies) != 1 { - t.Fatalf("No cookies. Header: %v", hdr) + t.Fatal("No cookies. Header:", hdr) } // Round 4 ---------------------------------------------------------------- @@ -148,7 +152,7 @@ func TestMongoStore(t *testing.T) { req, _ = http.NewRequest("GET", "http://localhost:8080/", nil) req.Header.Add("Cookie", cookies[0]) - rsp = httptest.NewRecorder() + rsp = NewRecorder() // Get a session. if session, err = store.Get(req, "session-key"); err != nil { t.Fatalf("Error getting session: %v", err) @@ -163,11 +167,71 @@ func TestMongoStore(t *testing.T) { t.Errorf("Expected %#v, got %#v", FlashMessage{42, "foo"}, custom) } - // Delete session. - session.Options.MaxAge = -1 - // Save. - if err = sessions.Save(req, rsp); err != nil { - t.Fatalf("Error saving session: %v", err) + // Round 5 ---------------------------------------------------------------- + // Check if a request shallow copy resets the request context data store. + + req, _ = http.NewRequest("GET", "http://localhost:8080/", nil) + + // Get a session. + if session, err = store.Get(req, "session-key"); err != nil { + t.Fatalf("Error getting session: %v", err) + } + + // Put a test value into the session data store. + session.Values["test"] = "test-value" + + // Create a shallow copy of the request. + req = req.WithContext(req.Context()) + + // Get the session again. + if session, err = store.Get(req, "session-key"); err != nil { + t.Fatalf("Error getting session: %v", err) + } + + // Check if the previous inserted value still exists. + if session.Values["test"] == nil { + t.Fatalf("Session test value is lost in the request context!") + } + + // Check if the previous inserted value has the same value. + if session.Values["test"] != "test-value" { + t.Fatalf("Session test value is changed in the request context!") + } +} + +func TestCookieStoreMapPanic(t *testing.T) { + defer func() { + err := recover() + if err != nil { + t.Fatal(err) + } + }() + + ctx, _ := context.WithTimeout(context.Background(), 3*time.Second) + client, err := mongo.Connect(ctx, options.Client().ApplyURI("mongodb://localhost:27017")) + err = client.Ping(ctx, readpref.Primary()) + if err != nil { + panic(err) + } + + store := mongostore.NewMongoStore(client, "test", "sessions", 3*time.Second, []byte("aaa0defe5d2839cbc46fc4f080cd7adc")) + store.Options.Path = "/" + store.Options.MaxAge = 86400 * 30 * 365 + defer store.Close() + + req, err := http.NewRequest("GET", "http://www.example.com", nil) + if err != nil { + t.Fatal("failed to create request", err) + } + w := httptest.NewRecorder() + + session := sessions.NewSession(store, "hello") + + session.Values["data"] = "hello-world" + + err = session.Save(req, w) + if err != nil { + t.Fatal("failed to save session", err) } }