mongostore/store_test.go
2020-04-02 21:32:02 +02:00

205 lines
5.7 KiB
Go

// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package mongostore_test
import (
"bytes"
"context"
"encoding/gob"
"github.com/dalu/mongostore"
"github.com/gorilla/sessions"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
"net/http"
"net/http/httptest"
"testing"
"time"
)
// NewRecorder returns an initialized ResponseRecorder.
//
// The recorder matches what httptest.NewRecorder produces: an empty header
// map, an empty body buffer and a status code of 200, which is the status a
// server reports when a handler never calls WriteHeader. Leaving Code at its
// zero value would make responses that only set headers (for example a
// Set-Cookie added by sessions.Save) report a status of 0.
func NewRecorder() *httptest.ResponseRecorder {
	return &httptest.ResponseRecorder{
		Code:      http.StatusOK,
		HeaderMap: make(http.Header),
		Body:      new(bytes.Buffer),
	}
}
// ----------------------------------------------------------------------------
// FlashMessage is a custom struct used as a flash payload, to check that
// values of a gob-registered user type survive a round trip through the store.
type FlashMessage struct {
	Type    int    // numeric tag carried with the message
	Message string // flash text
}
// TestFlashes exercises flash messages through a MongoStore backed by a
// MongoDB server at mongodb://localhost:27017.
//
// Round 1 adds string flashes (default and custom key) and saves the session.
// Round 2 replays the resulting cookie and checks that the flashes are
// returned once and then cleared. Rounds 3 and 4 do the same with the custom
// FlashMessage type. Round 5 checks that a shallow copy of the request made
// with WithContext still sees the session obtained for the original request.
func TestFlashes(t *testing.T) {
	var req *http.Request
	var rsp *httptest.ResponseRecorder
	var hdr http.Header
	var err error
	var ok bool
	var cookies []string
	var session *sessions.Session
	var flashes []interface{}

	ctx := context.Background()
	client, err := mongo.Connect(ctx, options.Client().ApplyURI("mongodb://localhost:27017"))
	if err != nil {
		t.Fatalf("Error connecting to MongoDB: %v", err)
	}
	// The test created the client, so it releases it. Deferred before
	// store.Close so it runs after it; the error is irrelevant at teardown.
	defer func() { _ = client.Disconnect(ctx) }()

	// Bound the ping so an unreachable server fails the test promptly rather
	// than waiting on the driver's own server-selection timeout.
	pingCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	err = client.Ping(pingCtx, readpref.Primary())
	cancel()
	if err != nil {
		t.Fatalf("Error pinging MongoDB: %v", err)
	}

	store := mongostore.NewMongoStore(client, "test", "sessions", 3*time.Second, []byte("hellol-worlol"))
	store.Options.Path = "/"
	store.Options.MaxAge = 86400 * 30 * 365 // seconds; roughly 30 years
	defer store.Close()

	// newRequest builds a GET request for the test site, attaching cookie
	// (a raw Set-Cookie value from a previous round) when it is non-empty.
	newRequest := func(cookie string) *http.Request {
		t.Helper()
		r, err := http.NewRequest("GET", "http://localhost:8080/", nil)
		if err != nil {
			t.Fatalf("Error creating request: %v", err)
		}
		if cookie != "" {
			r.Header.Add("Cookie", cookie)
		}
		return r
	}

	// Round 1 ----------------------------------------------------------------
	req = newRequest("")
	rsp = NewRecorder()
	// Get a session.
	if session, err = store.Get(req, "session-key"); err != nil {
		t.Fatalf("Error getting session: %v", err)
	}
	// Get a flash.
	flashes = session.Flashes()
	if len(flashes) != 0 {
		t.Errorf("Expected empty flashes; Got %v", flashes)
	}
	// Add some flashes.
	session.AddFlash("foo")
	session.AddFlash("bar")
	// Custom key.
	session.AddFlash("baz", "custom_key")
	// Save.
	if err = sessions.Save(req, rsp); err != nil {
		t.Fatalf("Error saving session: %v", err)
	}
	hdr = rsp.Header()
	cookies, ok = hdr["Set-Cookie"]
	if !ok || len(cookies) != 1 {
		t.Fatal("No cookies. Header:", hdr)
	}
	// An invalid cookie name must be rejected. Check for nil first so an
	// unexpected success fails the test instead of panicking on err.Error().
	if _, err = store.Get(req, "session:key"); err == nil || err.Error() != "sessions: invalid character in cookie name: session:key" {
		t.Fatalf("Expected error due to invalid cookie name; Got %v", err)
	}

	// Round 2 ----------------------------------------------------------------
	req = newRequest(cookies[0])
	rsp = NewRecorder()
	// Get a session.
	if session, err = store.Get(req, "session-key"); err != nil {
		t.Fatalf("Error getting session: %v", err)
	}
	// Check all saved values.
	flashes = session.Flashes()
	if len(flashes) != 2 {
		t.Fatalf("Expected flashes; Got %v", flashes)
	}
	if flashes[0] != "foo" || flashes[1] != "bar" {
		t.Errorf("Expected foo,bar; Got %v", flashes)
	}
	flashes = session.Flashes()
	if len(flashes) != 0 {
		t.Errorf("Expected dumped flashes; Got %v", flashes)
	}
	// Custom key.
	flashes = session.Flashes("custom_key")
	if len(flashes) != 1 {
		t.Errorf("Expected flashes; Got %v", flashes)
	} else if flashes[0] != "baz" {
		t.Errorf("Expected baz; Got %v", flashes)
	}
	flashes = session.Flashes("custom_key")
	if len(flashes) != 0 {
		t.Errorf("Expected dumped flashes; Got %v", flashes)
	}

	// Round 3 ----------------------------------------------------------------
	// Custom type
	req = newRequest("")
	rsp = NewRecorder()
	// Get a session.
	if session, err = store.Get(req, "session-key"); err != nil {
		t.Fatalf("Error getting session: %v", err)
	}
	// Get a flash.
	flashes = session.Flashes()
	if len(flashes) != 0 {
		t.Errorf("Expected empty flashes; Got %v", flashes)
	}
	// Add some flashes.
	session.AddFlash(&FlashMessage{42, "foo"})
	// Save.
	if err = sessions.Save(req, rsp); err != nil {
		t.Fatalf("Error saving session: %v", err)
	}
	hdr = rsp.Header()
	cookies, ok = hdr["Set-Cookie"]
	if !ok || len(cookies) != 1 {
		t.Fatal("No cookies. Header:", hdr)
	}

	// Round 4 ----------------------------------------------------------------
	// Custom type
	req = newRequest(cookies[0])
	rsp = NewRecorder()
	// Get a session.
	if session, err = store.Get(req, "session-key"); err != nil {
		t.Fatalf("Error getting session: %v", err)
	}
	// Check all saved values.
	flashes = session.Flashes()
	if len(flashes) != 1 {
		t.Fatalf("Expected flashes; Got %v", flashes)
	}
	// Use a checked assertion so a wrongly decoded value is reported as a
	// test failure rather than a panic.
	custom, ok := flashes[0].(FlashMessage)
	if !ok {
		t.Fatalf("Expected FlashMessage; Got %T", flashes[0])
	}
	if custom.Type != 42 || custom.Message != "foo" {
		t.Errorf("Expected %#v, got %#v", FlashMessage{42, "foo"}, custom)
	}

	// Round 5 ----------------------------------------------------------------
	// Check if a request shallow copy resets the request context data store.
	req = newRequest("")
	// Get a session.
	if session, err = store.Get(req, "session-key"); err != nil {
		t.Fatalf("Error getting session: %v", err)
	}
	// Put a test value into the session data store.
	session.Values["test"] = "test-value"
	// Create a shallow copy of the request.
	req = req.WithContext(req.Context())
	// Get the session again.
	if session, err = store.Get(req, "session-key"); err != nil {
		t.Fatalf("Error getting session: %v", err)
	}
	// Check if the previous inserted value still exists.
	if session.Values["test"] == nil {
		t.Fatalf("Session test value is lost in the request context!")
	}
	// Check if the previous inserted value has the same value.
	if session.Values["test"] != "test-value" {
		t.Fatalf("Session test value is changed in the request context!")
	}
}
// init registers FlashMessage with encoding/gob under its type name, so that
// FlashMessage values held in interface{} slots (such as session flashes) can
// be gob-encoded and decoded. This assumes the store serializes session values
// with encoding/gob — confirm against mongostore.
func init() {
	var msg FlashMessage
	gob.Register(msg)
}