// Copyright 2012 The Gorilla Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package mongostore_test import ( "bytes" "context" "encoding/gob" "github.com/dalu/mongostore" "github.com/gorilla/sessions" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readpref" "net/http" "net/http/httptest" "testing" "time" ) // NewRecorder returns an initialized ResponseRecorder. func NewRecorder() *httptest.ResponseRecorder { return &httptest.ResponseRecorder{ HeaderMap: make(http.Header), Body: new(bytes.Buffer), } } // ---------------------------------------------------------------------------- type FlashMessage struct { Type int Message string } func TestFlashes(t *testing.T) { var req *http.Request var rsp *httptest.ResponseRecorder var hdr http.Header var err error var ok bool var cookies []string var session *sessions.Session var flashes []interface{} ctx := context.Background() client, err := mongo.Connect(ctx, options.Client().ApplyURI("mongodb://localhost:27017")) err = client.Ping(ctx, readpref.Primary()) if err != nil { panic(err) } store := mongostore.NewMongoStore(client, "test", "sessions", 3*time.Second, []byte("hellol-worlol")) store.Options.Path = "/" store.Options.MaxAge = 86400 * 30 * 365 defer store.Close() // Round 1 ---------------------------------------------------------------- req, _ = http.NewRequest("GET", "http://localhost:8080/", nil) rsp = NewRecorder() // Get a session. if session, err = store.Get(req, "session-key"); err != nil { t.Fatalf("Error getting session: %v", err) } // Get a flash. flashes = session.Flashes() if len(flashes) != 0 { t.Errorf("Expected empty flashes; Got %v", flashes) } // Add some flashes. session.AddFlash("foo") session.AddFlash("bar") // Custom key. session.AddFlash("baz", "custom_key") // Save. if err = sessions.Save(req, rsp); err != nil { t.Fatalf("Error saving session: %v", err) } hdr = rsp.Header() cookies, ok = hdr["Set-Cookie"] if !ok || len(cookies) != 1 { t.Fatal("No cookies. Header:", hdr) } if _, err = store.Get(req, "session:key"); err.Error() != "sessions: invalid character in cookie name: session:key" { t.Fatalf("Expected error due to invalid cookie name") } // Round 2 ---------------------------------------------------------------- req, _ = http.NewRequest("GET", "http://localhost:8080/", nil) req.Header.Add("Cookie", cookies[0]) rsp = NewRecorder() // Get a session. if session, err = store.Get(req, "session-key"); err != nil { t.Fatalf("Error getting session: %v", err) } // Check all saved values. flashes = session.Flashes() if len(flashes) != 2 { t.Fatalf("Expected flashes; Got %v", flashes) } if flashes[0] != "foo" || flashes[1] != "bar" { t.Errorf("Expected foo,bar; Got %v", flashes) } flashes = session.Flashes() if len(flashes) != 0 { t.Errorf("Expected dumped flashes; Got %v", flashes) } // Custom key. flashes = session.Flashes("custom_key") if len(flashes) != 1 { t.Errorf("Expected flashes; Got %v", flashes) } else if flashes[0] != "baz" { t.Errorf("Expected baz; Got %v", flashes) } flashes = session.Flashes("custom_key") if len(flashes) != 0 { t.Errorf("Expected dumped flashes; Got %v", flashes) } // Round 3 ---------------------------------------------------------------- // Custom type req, _ = http.NewRequest("GET", "http://localhost:8080/", nil) rsp = NewRecorder() // Get a session. if session, err = store.Get(req, "session-key"); err != nil { t.Fatalf("Error getting session: %v", err) } // Get a flash. flashes = session.Flashes() if len(flashes) != 0 { t.Errorf("Expected empty flashes; Got %v", flashes) } // Add some flashes. session.AddFlash(&FlashMessage{42, "foo"}) // Save. if err = sessions.Save(req, rsp); err != nil { t.Fatalf("Error saving session: %v", err) } hdr = rsp.Header() cookies, ok = hdr["Set-Cookie"] if !ok || len(cookies) != 1 { t.Fatal("No cookies. Header:", hdr) } // Round 4 ---------------------------------------------------------------- // Custom type req, _ = http.NewRequest("GET", "http://localhost:8080/", nil) req.Header.Add("Cookie", cookies[0]) rsp = NewRecorder() // Get a session. if session, err = store.Get(req, "session-key"); err != nil { t.Fatalf("Error getting session: %v", err) } // Check all saved values. flashes = session.Flashes() if len(flashes) != 1 { t.Fatalf("Expected flashes; Got %v", flashes) } custom := flashes[0].(FlashMessage) if custom.Type != 42 || custom.Message != "foo" { t.Errorf("Expected %#v, got %#v", FlashMessage{42, "foo"}, custom) } // Round 5 ---------------------------------------------------------------- // Check if a request shallow copy resets the request context data store. req, _ = http.NewRequest("GET", "http://localhost:8080/", nil) // Get a session. if session, err = store.Get(req, "session-key"); err != nil { t.Fatalf("Error getting session: %v", err) } // Put a test value into the session data store. session.Values["test"] = "test-value" // Create a shallow copy of the request. req = req.WithContext(req.Context()) // Get the session again. if session, err = store.Get(req, "session-key"); err != nil { t.Fatalf("Error getting session: %v", err) } // Check if the previous inserted value still exists. if session.Values["test"] == nil { t.Fatalf("Session test value is lost in the request context!") } // Check if the previous inserted value has the same value. if session.Values["test"] != "test-value" { t.Fatalf("Session test value is changed in the request context!") } } func init() { gob.Register(FlashMessage{}) }