diff --git a/api/comment.go b/api/comment.go new file mode 100644 index 0000000..e15a04e --- /dev/null +++ b/api/comment.go @@ -0,0 +1,19 @@ +package main + +import ( + "time" +) + +type comment struct { + CommentHex string `json:"commentHex"` + Domain string `json:"domain,omitempty"` + Path string `json:"url,omitempty"` + CommenterHex string `json:"commenterHex"` + Markdown string `json:"markdown"` + Html string `json:"html"` + ParentHex string `json:"parentHex"` + Score int `json:"score"` + State string `json:"-"` + CreationDate time.Time `json:"creationDate"` + VoteDirection int `json:"voteDirection"` +} diff --git a/api/comment_approve.go b/api/comment_approve.go new file mode 100644 index 0000000..51b76d1 --- /dev/null +++ b/api/comment_approve.go @@ -0,0 +1,68 @@ +package main + +import ( + "net/http" +) + +func commentApprove(commentHex string) error { + if commentHex == "" { + return errorMissingField + } + + statement := ` + UPDATE comments + SET state = 'approved' + WHERE commentHex = $1; + ` + + _, err := db.Exec(statement, commentHex) + if err != nil { + logger.Errorf("cannot approve comment: %v", err) + return errorInternal + } + + return nil +} + +func commentApproveHandler(w http.ResponseWriter, r *http.Request) { + type request struct { + Session *string `json:"session"` + CommentHex *string `json:"commentHex"` + } + + var x request + if err := unmarshalBody(r, &x); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + c, err := commenterGetBySession(*x.Session) + if err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + domain, err := commentDomainGet(*x.CommentHex) + if err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + isModerator, err := isDomainModerator(c.Email, domain) + if err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + if !isModerator { + writeBody(w, response{"success": false, "message": errorNotModerator.Error()}) + return + } + + if err = commentApprove(*x.CommentHex); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + writeBody(w, response{"success": true}) +} diff --git a/api/comment_approve_test.go b/api/comment_approve_test.go new file mode 100644 index 0000000..c5594f2 --- /dev/null +++ b/api/comment_approve_test.go @@ -0,0 +1,33 @@ +package main + +import ( + "testing" + "time" +) + +func TestCommentApproveBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + commenterHex, _ := commenterNew("test@example.com", "Test", "undefined", "https://example.com/photo.jpg", "google") + + commentHex, _ := commentNew(commenterHex, "example.com", "/path.html", "root", "**foo**", "unapproved", time.Now().UTC()) + + if err := commentApprove(commentHex); err != nil { + t.Errorf("unexpected error approving comment: %v", err) + return + } + + if c, _, _ := commentList("anonymous", "example.com", "/path.html", false); c[0].State != "approved" { + t.Errorf("expected state = approved got state = %s", c[0].State) + return + } +} + +func TestCommentApproveEmpty(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if err := commentApprove(""); err == nil { + t.Errorf("expected error not found approving comment with empty commentHex") + return + } +} diff --git a/api/comment_delete.go b/api/comment_delete.go new file mode 100644 index 0000000..9b3b386 --- /dev/null +++ b/api/comment_delete.go @@ -0,0 +1,67 @@ +package main + +import ( + "net/http" +) + +func commentDelete(commentHex string) error { + if commentHex == "" { + return errorMissingField + } + + statement := ` + DELETE FROM comments + WHERE commentHex=$1; + ` + _, err := db.Exec(statement, commentHex) + + if err != nil { + // TODO: make sure this is the error is actually non-existant commentHex + return errorNoSuchComment + } + + return nil +} + +func commentDeleteHandler(w http.ResponseWriter, r *http.Request) { + type request struct { + Session *string `json:"session"` + CommentHex *string `json:"commentHex"` + } + + var x request + if err := unmarshalBody(r, &x); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + c, err := commenterGetBySession(*x.Session) + if err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + domain, err := commentDomainGet(*x.CommentHex) + if err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + isModerator, err := isDomainModerator(c.Email, domain) + if err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + if !isModerator { + writeBody(w, response{"success": false, "message": errorNotModerator.Error()}) + return + } + + if err = commentDelete(*x.CommentHex); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + writeBody(w, response{"success": true}) +} diff --git a/api/comment_delete_test.go b/api/comment_delete_test.go new file mode 100644 index 0000000..14e77c7 --- /dev/null +++ b/api/comment_delete_test.go @@ -0,0 +1,34 @@ +package main + +import ( + "testing" + "time" +) + +func TestCommentDeleteBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + commentHex, _ := commentNew("temp-commenter-hex", "example.com", "/path.html", "root", "**foo**", "approved", time.Now().UTC()) + commentNew("temp-commenter-hex", "example.com", "/path.html", commentHex, "**bar**", "approved", time.Now().UTC()) + + if err := commentDelete(commentHex); err != nil { + t.Errorf("unexpected error deleting comment: %v", err) + return + } + + c, _, _ := commentList("temp-commenter-hex", "example.com", "/path.html", false) + + if len(c) != 0 { + t.Errorf("expected no comments found %d comments", len(c)) + return + } +} + +func TestCommentDeleteEmpty(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if err := commentDelete(""); err == nil { + t.Errorf("expected error deleting comment with empty commentHex") + return + } +} diff --git a/api/comment_domain_get.go b/api/comment_domain_get.go new file mode 100644 index 0000000..b5f0afb --- /dev/null +++ b/api/comment_domain_get.go @@ -0,0 +1,24 @@ +package main + +import () + +func commentDomainGet(commentHex string) (string, error) { + if commentHex == "" { + return "", errorMissingField + } + + statement := ` + SELECT domain + FROM comments + WHERE commentHex = $1; + ` + row := db.QueryRow(statement, commentHex) + + var domain string + var err error + if err = row.Scan(&domain); err != nil { + return "", errorNoSuchDomain + } + + return domain, nil +} diff --git a/api/comment_domain_get_test.go b/api/comment_domain_get_test.go new file mode 100644 index 0000000..14d6506 --- /dev/null +++ b/api/comment_domain_get_test.go @@ -0,0 +1,32 @@ +package main + +import ( + "testing" + "time" +) + +func TestCommentDomainGetBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + commentHex, _ := commentNew("temp-commenter-hex", "example.com", "/path.html", "root", "**foo**", "approved", time.Now().UTC()) + + domain, err := commentDomainGet(commentHex) + if err != nil { + t.Errorf("unexpected error getting domain by hex: %v", err) + return + } + + if domain != "example.com" { + t.Errorf("expected domain = example.com got domain = %s", domain) + return + } +} + +func TestCommentDomainGetEmpty(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if _, err := commentDomainGet(""); err == nil { + t.Errorf("expected error not found getting domain with empty commentHex") + return + } +} diff --git a/api/comment_list.go b/api/comment_list.go new file mode 100644 index 0000000..73c19ae --- /dev/null +++ b/api/comment_list.go @@ -0,0 +1,151 @@ +package main + +import ( + "database/sql" + "net/http" +) + +func commentList(commenterHex string, domain string, path string, includeUnapproved bool) ([]comment, map[string]commenter, error) { + if commenterHex == "" || domain == "" || path == "" { + return nil, nil, errorMissingField + } + + statement := ` + SELECT commentHex, commenterHex, markdown, html, parentHex, score, state, creationDate + FROM comments + WHERE + comments.domain = $1 AND + comments.path = $2 + ` + + if !includeUnapproved { + if commenterHex == "anonymous" { + statement += ` + AND state = 'approved' + ` + } else { + statement += ` + AND (state = 'approved' OR commenterHex = $3) + ` + } + } + + statement += `;` + + var rows *sql.Rows + var err error + + if !includeUnapproved && commenterHex != "anonymous" { + rows, err = db.Query(statement, domain, path, commenterHex) + } else { + rows, err = db.Query(statement, domain, path) + } + + if err != nil { + logger.Errorf("cannot get comments: %v", err) + return nil, nil, errorInternal + } + defer rows.Close() + + commenters := make(map[string]commenter) + commenters["anonymous"] = commenter{CommenterHex: "anonymous", Email: "undefined", Name: "Anonymous", Link: "undefined", Photo: "undefined", Provider: "undefined"} + + comments := []comment{} + for rows.Next() { + c := comment{} + if err = rows.Scan(&c.CommentHex, &c.CommenterHex, &c.Markdown, &c.Html, &c.ParentHex, &c.Score, &c.State, &c.CreationDate); err != nil { + return nil, nil, errorInternal + } + + if commenterHex != "anonymous" { + statement = ` + SELECT direction + FROM votes + WHERE commentHex=$1 AND commenterHex=$2; + ` + row := db.QueryRow(statement, c.CommentHex, commenterHex) + + if err = row.Scan(&c.VoteDirection); err != nil { + // TODO: is the only error here that there is no such entry? + c.VoteDirection = 0 + } + } + + comments = append(comments, c) + + if _, ok := commenters[c.CommenterHex]; !ok { + commenters[c.CommenterHex], err = commenterGetByHex(c.CommenterHex) + if err != nil { + logger.Errorf("cannot retrieve commenter: %v", err) + return nil, nil, errorInternal + } + } + } + + return comments, commenters, nil +} + +func commentListHandler(w http.ResponseWriter, r *http.Request) { + type request struct { + Session *string `json:"session"` + Domain *string `json:"domain"` + Path *string `json:"path"` + } + + var x request + if err := unmarshalBody(r, &x); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + domain := stripDomain(*x.Domain) + path := *x.Path + + d, err := domainGet(domain) + if err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + commenterHex := "anonymous" + isModerator := false + if *x.Session != "anonymous" { + c, err := commenterGetBySession(*x.Session) + if err != nil { + if err == errorNoSuchSession { + commenterHex = "anonymous" + } else { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + } else { + commenterHex = c.CommenterHex + } + + for _, mod := range d.Moderators { + if mod.Email == c.Email { + isModerator = true + break + } + } + } + + domainViewRecord(domain, commenterHex) + + comments, commenters, err := commentList(commenterHex, domain, path, isModerator) + if err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + writeBody(w, response{ + "success": true, + "domain": domain, + "comments": comments, + "commenters": commenters, + "requireModeration": d.RequireModeration, + "requireIdentification": d.RequireIdentification, + "isFrozen": d.State == "frozen", + "isModerator": isModerator, + }) +} diff --git a/api/comment_list_test.go b/api/comment_list_test.go new file mode 100644 index 0000000..e3ca0f1 --- /dev/null +++ b/api/comment_list_test.go @@ -0,0 +1,154 @@ +package main + +import ( + "strings" + "testing" + "time" +) + +func TestCommentListBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + commenterHex, _ := commenterNew("test@example.com", "Test", "undefined", "http://example.com/photo.jpg", "google") + + commentNew(commenterHex, "example.com", "/path.html", "root", "**foo**", "approved", time.Now().UTC()) + commentNew(commenterHex, "example.com", "/path.html", "root", "**bar**", "approved", time.Now().UTC()) + + c, _, err := commentList("temp-commenter-hex", "example.com", "/path.html", false) + if err != nil { + t.Errorf("unexpected error listing page comments: %v", err) + return + } + + if len(c) != 2 { + t.Errorf("expected 2 comments got %d comments", len(c)) + return + } + + if c[0].VoteDirection != 0 { + t.Errorf("expected c.VoteDirection = 0 got c.VoteDirection = %d", c[0].VoteDirection) + return + } + + c1Html := strings.TrimSpace(c[1].Html) + if c1Html != "

bar

" { + t.Errorf("expected c[1].Html=[

bar

] got c[1].Html=[%s]", c1Html) + return + } + + c, _, err = commentList(commenterHex, "example.com", "/path.html", false) + if err != nil { + t.Errorf("unexpected error listing page comments: %v", err) + return + } + + if len(c) != 2 { + t.Errorf("expected 2 comments got %d comments", len(c)) + return + } + + if c[0].VoteDirection != 1 { + t.Errorf("expected c.VoteDirection = 1 got c.VoteDirection = %d", c[0].VoteDirection) + return + } +} + +func TestCommentListEmpty(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if _, _, err := commentList("temp-commenter-hex", "", "/path.html", false); err == nil { + t.Errorf("expected error not found listing comments with empty domain") + return + } +} + +func TestCommentListSelfUnapproved(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + commenterHex, _ := commenterNew("test@example.com", "Test", "undefined", "http://example.com/photo.jpg", "google") + + commentNew(commenterHex, "example.com", "/path.html", "root", "**foo**", "unapproved", time.Now().UTC()) + + c, _, _ := commentList("temp-commenter-hex", "example.com", "/path.html", false) + + if len(c) != 0 { + t.Errorf("expected user to not see unapproved comment") + return + } + + c, _, _ = commentList(commenterHex, "example.com", "/path.html", false) + + if len(c) != 1 { + t.Errorf("expected user to see unapproved self comment") + return + } +} + +func TestCommentListAnonymousUnapproved(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + commentNew("anonymous", "example.com", "/path.html", "root", "**foo**", "unapproved", time.Now().UTC()) + + c, _, _ := commentList("anonymous", "example.com", "/path.html", false) + + if len(c) != 0 { + t.Errorf("expected user to not see unapproved anonymous comment as anonymous") + return + } +} + +func TestCommentListIncludeUnapproved(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + commentNew("anonymous", "example.com", "/path.html", "root", "**foo**", "unapproved", time.Now().UTC()) + + c, _, _ := commentList("anonymous", "example.com", "/path.html", true) + + if len(c) != 1 { + t.Errorf("expected to see unapproved comments because includeUnapproved was true") + return + } +} + +func TestCommentListDifferentPaths(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + commentNew("anonymous", "example.com", "/path1.html", "root", "**foo**", "unapproved", time.Now().UTC()) + commentNew("anonymous", "example.com", "/path1.html", "root", "**foo**", "unapproved", time.Now().UTC()) + commentNew("anonymous", "example.com", "/path2.html", "root", "**foo**", "unapproved", time.Now().UTC()) + + c, _, _ := commentList("anonymous", "example.com", "/path1.html", true) + + if len(c) != 2 { + t.Errorf("expected len(c) = 2 got len(c) = %d", len(c)) + return + } + + c, _, _ = commentList("anonymous", "example.com", "/path2.html", true) + + if len(c) != 1 { + t.Errorf("expected len(c) = 1 got len(c) = %d", len(c)) + return + } +} + +func TestCommentListDifferentDomains(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + commentNew("anonymous", "example1.com", "/path.html", "root", "**foo**", "unapproved", time.Now().UTC()) + commentNew("anonymous", "example2.com", "/path.html", "root", "**foo**", "unapproved", time.Now().UTC()) + + c, _, _ := commentList("anonymous", "example1.com", "/path.html", true) + + if len(c) != 1 { + t.Errorf("expected len(c) = 1 got len(c) = %d", len(c)) + return + } + + c, _, _ = commentList("anonymous", "example2.com", "/path.html", true) + + if len(c) != 1 { + t.Errorf("expected len(c) = 1 got len(c) = %d", len(c)) + return + } +} diff --git a/api/comment_new.go b/api/comment_new.go new file mode 100644 index 0000000..88c8edf --- /dev/null +++ b/api/comment_new.go @@ -0,0 +1,120 @@ +package main + +import ( + "net/http" + "time" +) + +// Take `creationDate` as a param because comment import (from Disqus, for +// example) will require a custom time. +func commentNew(commenterHex string, domain string, path string, parentHex string, markdown string, state string, creationDate time.Time) (string, error) { + // path is allowed to be empty + if commenterHex == "" || domain == "" || parentHex == "" || markdown == "" || state == "" { + return "", errorMissingField + } + + commentHex, err := randomHex(32) + if err != nil { + return "", err + } + + html := markdownToHtml(markdown) + + statement := ` + INSERT INTO + comments (commentHex, domain, path, commenterHex, parentHex, markdown, html, creationDate, state) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9 ); + ` + _, err = db.Exec(statement, commentHex, domain, path, commenterHex, parentHex, markdown, html, creationDate, state) + if err != nil { + logger.Errorf("cannot insert comment: %v", err) + return "", errorInternal + } + + if err = commentVote(commenterHex, commentHex, 1); err != nil { + logger.Warningf("error: cannot upvote new comment automatically: %v", err) + } + + return commentHex, nil +} + +func commentNewHandler(w http.ResponseWriter, r *http.Request) { + type request struct { + Session *string `json:"session"` + Domain *string `json:"domain"` + Path *string `json:"path"` + ParentHex *string `json:"parentHex"` + Markdown *string `json:"markdown"` + } + + var x request + if err := unmarshalBody(r, &x); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + domain := stripDomain(*x.Domain) + path := *x.Path + + d, err := domainGet(domain) + if err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + if d.State == "frozen" { + writeBody(w, response{"success": false, "message": errorDomainFrozen.Error()}) + return + } + + // logic: (empty column indicates the value doesn't matter) + // | anonymous | moderator | requireIdentification | requireModeration | approved? | + // |-----------+-----------+-----------------------+-------------------+-----------| + // | yes | | | | no | + // | no | yes | | | yes | + // | no | no | | yes | yes | + // | no | no | | no | no | + + var commenterHex string + var state string + + if *x.Session == "anonymous" { + state = "unapproved" + commenterHex = "anonymous" + } else { + c, err := commenterGetBySession(*x.Session) + if err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + // cheaper than a SQL query as we already have this information + isModerator := false + for _, mod := range d.Moderators { + if mod.Email == c.Email { + isModerator = true + break + } + } + + commenterHex = c.CommenterHex + + if isModerator { + state = "approved" + } else { + if d.RequireModeration { + state = "unapproved" + } else { + state = "approved" + } + } + } + + commentHex, err := commentNew(commenterHex, domain, path, *x.ParentHex, *x.Markdown, state, time.Now().UTC()) + if err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + writeBody(w, response{"success": true, "commentHex": commentHex}) +} diff --git a/api/comment_new_test.go b/api/comment_new_test.go new file mode 100644 index 0000000..3d29815 --- /dev/null +++ b/api/comment_new_test.go @@ -0,0 +1,58 @@ +package main + +import ( + "testing" + "time" +) + +func TestCommentNewBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if _, err := commentNew("temp-commenter-hex", "example.com", "/path.html", "root", "**foo**", "approved", time.Now().UTC()); err != nil { + t.Errorf("unexpected error creating new comment: %v", err) + return + } +} + +func TestCommentNewEmpty(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if _, err := commentNew("temp-commenter-hex", "example.com", "", "root", "**foo**", "approved", time.Now().UTC()); err != nil { + t.Errorf("empty path not allowed: %v", err) + return + } + + if _, err := commentNew("temp-commenter-hex", "", "", "root", "**foo**", "approved", time.Now().UTC()); err == nil { + t.Errorf("expected error not found creatingn new comment with empty domain") + return + } + + if _, err := commentNew("", "", "", "", "", "", time.Now().UTC()); err == nil { + t.Errorf("expected error not found creatingn new comment with empty everything") + return + } +} + +func TestCommentNewUpvoted(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + commentHex, _ := commentNew("temp-commenter-hex", "example.com", "/path.html", "root", "**foo**", "approved", time.Now().UTC()) + + statement := ` + SELECT score + FROM comments + WHERE commentHex = $1; + ` + row := db.QueryRow(statement, commentHex) + + var score int + if err := row.Scan(&score); err != nil { + t.Errorf("error scanning score from comments table: %v", err) + return + } + + if score != 1 { + t.Errorf("expected comment to be auto-upvoted") + return + } +} diff --git a/api/comment_ownership_verify.go b/api/comment_ownership_verify.go new file mode 100644 index 0000000..2b9b693 --- /dev/null +++ b/api/comment_ownership_verify.go @@ -0,0 +1,26 @@ +package main + +import () + +func commentOwnershipVerify(commenterHex string, commentHex string) (bool, error) { + if commenterHex == "" || commentHex == "" { + return false, errorMissingField + } + + statement := ` + SELECT EXISTS ( + SELECT 1 + FROM comments + WHERE commenterHex=$1 AND commentHex=$2 + ); + ` + row := db.QueryRow(statement, commenterHex, commentHex) + + var exists bool + if err := row.Scan(&exists); err != nil { + logger.Errorf("cannot query if comment owner: %v", err) + return false, errorInternal + } + + return exists, nil +} diff --git a/api/comment_ownership_verify_test.go b/api/comment_ownership_verify_test.go new file mode 100644 index 0000000..2245b0d --- /dev/null +++ b/api/comment_ownership_verify_test.go @@ -0,0 +1,43 @@ +package main + +import ( + "testing" + "time" +) + +func TestCommentOwnershipVerifyBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + commentHex, _ := commentNew("temp-commenter-hex", "example.com", "/path.html", "root", "**foo**", "approved", time.Now().UTC()) + + isOwner, err := commentOwnershipVerify("temp-commenter-hex", commentHex) + if err != nil { + t.Errorf("unexpected error verifying ownership: %v", err) + return + } + + if !isOwner { + t.Errorf("expected to be owner of comment") + return + } + + isOwner, err = commentOwnershipVerify("another-commenter-hex", commentHex) + if err != nil { + t.Errorf("unexpected error verifying ownership: %v", err) + return + } + + if isOwner { + t.Errorf("unexpected owner of comment not created by another-commenter-hex") + return + } +} + +func TestCommentOwnershipVerifyEmpty(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if _, err := commentOwnershipVerify("temp-commenter-hex", ""); err == nil { + t.Errorf("expected error not founding verifying ownership with empty commentHex") + return + } +} diff --git a/api/comment_vote.go b/api/comment_vote.go new file mode 100644 index 0000000..591c6d1 --- /dev/null +++ b/api/comment_vote.go @@ -0,0 +1,66 @@ +package main + +import ( + "net/http" + "time" +) + +func commentVote(commenterHex string, commentHex string, direction int) error { + if commentHex == "" || commenterHex == "" { + return errorMissingField + } + + statement := ` + INSERT INTO + votes (commentHex, commenterHex, direction, voteDate) + VALUES ($1, $2, $3, $4 ) + ON CONFLICT (commentHex, commenterHex) DO + UPDATE SET direction = $3; + ` + _, err := db.Exec(statement, commentHex, commenterHex, direction, time.Now().UTC()) + if err != nil { + logger.Errorf("error inserting/updating votes: %v", err) + return errorInternal + } + + return nil +} + +func commentVoteHandler(w http.ResponseWriter, r *http.Request) { + type request struct { + Session *string `json:"session"` + CommentHex *string `json:"commentHex"` + Direction *int `json:"direction"` + } + + var x request + if err := unmarshalBody(r, &x); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + if *x.Session == "anonymous" { + writeBody(w, response{"success": false, "message": errorUnauthorisedVote.Error()}) + return + } + + c, err := commenterGetBySession(*x.Session) + if err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + direction := 0 + if *x.Direction > 0 { + direction = 1 + } else if *x.Direction < 0 { + direction = -1 + } + + if err := commentVote(c.CommenterHex, *x.CommentHex, direction); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + writeBody(w, response{"success": true}) +} diff --git a/api/comment_vote_test.go b/api/comment_vote_test.go new file mode 100644 index 0000000..74bd22d --- /dev/null +++ b/api/comment_vote_test.go @@ -0,0 +1,55 @@ +package main + +import ( + "testing" + "time" +) + +func TestCommentVoteBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + cr0, _ := commenterNew("test1@example.com", "Test1", "undefined", "http://example.com/photo.jpg", "google") + cr1, _ := commenterNew("test2@example.com", "Test2", "undefined", "http://example.com/photo.jpg", "google") + cr2, _ := commenterNew("test3@example.com", "Test3", "undefined", "http://example.com/photo.jpg", "google") + + c0, _ := commentNew(cr0, "example.com", "/path.html", "root", "**foo**", "approved", time.Now().UTC()) + + commentVote(cr0, c0, -1) + if c, _, _ := commentList("temp", "example.com", "/path.html", false); c[0].Score != -1 { + t.Errorf("expected c[0].Score = -1 got c[0].Score = %d", c[0].Score) + return + } + + commentVote(cr1, c0, -1) + commentVote(cr2, c0, -1) + if c, _, _ := commentList("temp", "example.com", "/path.html", false); c[0].Score != -3 { + t.Errorf("expected c[0].Score = -3 got c[0].Score = %d", c[0].Score) + return + } + + commentVote(cr1, c0, -1) + if c, _, _ := commentList("temp", "example.com", "/path.html", false); c[0].Score != -3 { + t.Errorf("expected c[0].Score = -3 got c[0].Score = %d", c[0].Score) + return + } + + commentVote(cr1, c0, 0) + if c, _, _ := commentList("temp", "example.com", "/path.html", false); c[0].Score != -2 { + t.Errorf("expected c[0].Score = -2 got c[0].Score = %d", c[0].Score) + return + } + + c1, _ := commentNew(cr1, "example.com", "/path.html", "root", "**bar**", "approved", time.Now().UTC()) + + commentVote(cr0, c1, 0) + if c, _, _ := commentList("temp", "example.com", "/path.html", false); c[1].Score != 1 { + t.Errorf("expected c[1].Score = 1 got c[1].Score = %d", c[1].Score) + return + } + + commentVote(cr1, c1, 0) + if c, _, _ := commentList("temp", "example.com", "/path.html", false); c[1].Score != 0 { + t.Errorf("expected c[1].Score = 0 got c[1].Score = %d", c[1].Score) + return + } +} diff --git a/api/commenter.go b/api/commenter.go new file mode 100644 index 0000000..a3a795b --- /dev/null +++ b/api/commenter.go @@ -0,0 +1,38 @@ +package main + +import ( + "time" +) + +type commenter struct { + CommenterHex string `json:"commenterHex,omitempty"` + Email string `json:"email,omitempty"` + Name string `json:"name"` + Link string `json:"link"` + Photo string `json:"photo"` + Provider string `json:"provider,omitempty"` + JoinDate time.Time `json:"joinDate,omitempty"` +} + +func commenterIsProviderUser(provider string, email string) (bool, error) { + if provider == "" || email == "" { + return false, errorMissingField + } + + statement := ` + SELECT EXISTS ( + SELECT 1 + FROM commenters + WHERE email=$1 AND provider=$2 + ); + ` + row := db.QueryRow(statement, email, provider) + + var exists bool + if err := row.Scan(&exists); err != nil { + logger.Errorf("error checking if provider user exists: %v", err) + return false, errorInternal + } + + return exists, nil +} diff --git a/api/commenter_get.go b/api/commenter_get.go new file mode 100644 index 0000000..9c6f8d6 --- /dev/null +++ b/api/commenter_get.go @@ -0,0 +1,45 @@ +package main + +import () + +func commenterGetByHex(commenterHex string) (commenter, error) { + if commenterHex == "" { + return commenter{}, errorMissingField + } + + statement := ` + SELECT commenterHex, email, name, link, photo, provider, joinDate + FROM commenters + WHERE commenterHex=$1; + ` + row := db.QueryRow(statement, commenterHex) + + c := commenter{} + if err := row.Scan(&c.CommenterHex, &c.Email, &c.Name, &c.Link, &c.Photo, &c.Provider, &c.JoinDate); err != nil { + logger.Errorf("error scanning commenter: %v", err) + return commenter{}, errorInternal + } + + return c, nil +} + +func commenterGetBySession(session string) (commenter, error) { + if session == "" { + return commenter{}, errorMissingField + } + + statement := ` + SELECT commenterHex + FROM commenterSessions + WHERE session=$1; + ` + row := db.QueryRow(statement, session) + + var commenterHex string + if err := row.Scan(&commenterHex); err != nil { + // TODO: is the only error? + return commenter{}, errorNoSuchSession + } + + return commenterGetByHex(commenterHex) +} diff --git a/api/commenter_get_test.go b/api/commenter_get_test.go new file mode 100644 index 0000000..5e223dd --- /dev/null +++ b/api/commenter_get_test.go @@ -0,0 +1,61 @@ +package main + +import ( + "testing" +) + +func TestCommenterGetByHexBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + commenterHex, _ := commenterNew("test@example.com", "Test", "undefined", "https://example.com/photo.jpg", "google") + + c, err := commenterGetByHex(commenterHex) + if err != nil { + t.Errorf("unexpected error getting commenter by hex: %v", err) + return + } + + if c.Name != "Test" { + t.Errorf("expected name=Test got name=%s", c.Name) + return + } +} + +func TestCommenterGetByHexEmpty(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if _, err := commenterGetByHex(""); err == nil { + t.Errorf("expected error not found getting commenter with empty hex") + return + } +} + +func TestCommenterGetBySession(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + commenterHex, _ := commenterNew("test@example.com", "Test", "undefined", "https://example.com/photo.jpg", "google") + + session, _ := commenterSessionNew() + + commenterSessionUpdate(session, commenterHex) + + c, err := commenterGetBySession(session) + if err != nil { + t.Errorf("unexpected error getting commenter by hex: %v", err) + return + } + + if c.Name != "Test" { + t.Errorf("expected name=Test got name=%s", c.Name) + return + } +} + +func TestCommenterGetBySessionEmpty(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if _, err := commenterGetBySession(""); err == nil { + t.Errorf("expected error not found getting commenter with empty session") + return + } +} diff --git a/api/commenter_new.go b/api/commenter_new.go new file mode 100644 index 0000000..76e3d14 --- /dev/null +++ b/api/commenter_new.go @@ -0,0 +1,28 @@ +package main + +import ( + "time" +) + +func commenterNew(email string, name string, link string, photo string, provider string) (string, error) { + if email == "" || name == "" || link == "" || photo == "" || provider == "" { + return "", errorMissingField + } + + commenterHex, err := randomHex(32) + if err != nil { + return "", errorInternal + } + + statement := ` + INSERT INTO + commenters (commenterHex, email, name, link, photo, provider, joinDate) + VALUES ($1, $2, $3, $4, $5, $6, $7 ); + ` + _, err = db.Exec(statement, commenterHex, email, name, link, photo, provider, time.Now().UTC()) + if err != nil { + return "", errorInternal + } + + return commenterHex, nil +} diff --git a/api/commenter_new_test.go b/api/commenter_new_test.go new file mode 100644 index 0000000..e40b947 --- /dev/null +++ b/api/commenter_new_test.go @@ -0,0 +1,28 @@ +package main + +import ( + "testing" +) + +func TestCommenterNewBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if _, err := commenterNew("test@example.com", "Test", "undefined", "https://example.com/photo.jpg", "google"); err != nil { + t.Errorf("unexpected error creating new commenter: %v", err) + return + } +} + +func TestCommenterNewEmpty(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if _, err := commenterNew("", "Test", "undefined", "https://example.com/photo.jpg", "google"); err == nil { + t.Errorf("expected error not found creating new commenter with empty email") + return + } + + if _, err := commenterNew("", "", "", "", ""); err == nil { + t.Errorf("expected error not found creating new commenter with empty everything") + return + } +} diff --git a/api/commenter_self.go b/api/commenter_self.go new file mode 100644 index 0000000..804563d --- /dev/null +++ b/api/commenter_self.go @@ -0,0 +1,25 @@ +package main + +import ( + "net/http" +) + +func commenterSelfHandler(w http.ResponseWriter, r *http.Request) { + type request struct { + Session *string `json:"session"` + } + + var x request + if err := unmarshalBody(r, &x); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + c, err := commenterGetBySession(*x.Session) + if err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + writeBody(w, response{"success": true, "commenter": c}) +} diff --git a/api/commenter_session.go b/api/commenter_session.go new file mode 100644 index 0000000..77c933e --- /dev/null +++ b/api/commenter_session.go @@ -0,0 +1,11 @@ +package main + +import ( + "time" +) + +type commenterSession struct { + Session string `json:"session"` + CommenterHex string `json:"commenterHex"` + CreationDate time.Time `json:"creationDate"` +} diff --git a/api/commenter_session_get.go b/api/commenter_session_get.go new file mode 100644 index 0000000..ae9731e --- /dev/null +++ b/api/commenter_session_get.go @@ -0,0 +1,25 @@ +package main + +import () + +func commenterSessionGet(session string) (commenterSession, error) { + if session == "" { + return commenterSession{}, errorMissingField + } + + statement := ` + SELECT commenterHex, creationDate + FROM commenterSessions + WHERE session=$1; + ` + row := db.QueryRow(statement, session) + + cs := commenterSession{} + if err := row.Scan(&cs.CommenterHex, &cs.CreationDate); err != nil { + return commenterSession{}, errorNoSuchSession + } + + cs.Session = session + + return cs, nil +} diff --git a/api/commenter_session_get_test.go b/api/commenter_session_get_test.go new file mode 100644 index 0000000..9643593 --- /dev/null +++ b/api/commenter_session_get_test.go @@ -0,0 +1,46 @@ +package main + +import ( + "testing" +) + +func TestCommenterSessionGetBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + commenterHex, _ := commenterNew("test@example.com", "Test", "undefined", "https://example.com/photo.jpg", "google") + + session, _ := commenterSessionNew() + + commenterSessionUpdate(session, commenterHex) + + cs, err := commenterSessionGet(session) + if err != nil { + t.Errorf("unexpected error found when getting session information: %v", err) + return + } + + if cs.CommenterHex != commenterHex { + t.Errorf("expected commenterHex=%s got commenterHex=%s", commenterHex, cs.CommenterHex) + return + } +} + +func TestCommenterSessionGetDNE(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + _, err := commenterSessionGet("does-not-exist") + if err == nil { + t.Errorf("expected error not found when invalid session") + return + } +} + +func TestCommenterSessionGetEmpty(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + _, err := commenterSessionGet("") + if err == nil { + t.Errorf("expected error not found with empty session") + return + } +} diff --git a/api/commenter_session_new.go b/api/commenter_session_new.go new file mode 100644 index 0000000..eba6581 --- /dev/null +++ b/api/commenter_session_new.go @@ -0,0 +1,37 @@ +package main + +import ( + "net/http" + "time" +) + +func commenterSessionNew() (string, error) { + session, err := randomHex(32) + if err != nil { + logger.Errorf("cannot create session hex: %v", err) + return "", errorInternal + } + + statement := ` + INSERT INTO + commenterSessions (session, creationDate) + VALUES ($1, $2 ); + ` + _, err = db.Exec(statement, session, time.Now().UTC()) + if err != nil { + logger.Errorf("cannot insert new session: %v", err) + return "", errorInternal + } + + return session, nil +} + +func commenterSessionNewHandler(w http.ResponseWriter, r *http.Request) { + session, err := commenterSessionNew() + if err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + writeBody(w, response{"success": true, "session": session}) +} diff --git a/api/commenter_session_new_test.go b/api/commenter_session_new_test.go new file mode 100644 index 0000000..e058107 --- /dev/null +++ b/api/commenter_session_new_test.go @@ -0,0 +1,14 @@ +package main + +import ( + "testing" +) + +func TestCommenterSessionNewBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if _, err := commenterSessionNew(); err != nil { + t.Errorf("unexpected error creating new session: %v", err) + return + } +} diff --git a/api/commenter_session_update.go b/api/commenter_session_update.go new file mode 100644 index 0000000..a9ffb7c --- /dev/null +++ b/api/commenter_session_update.go @@ -0,0 +1,22 @@ +package main + +import () + +func commenterSessionUpdate(session string, commenterHex string) error { + if session == "" || commenterHex == "" { + return errorMissingField + } + + statement := ` + UPDATE commenterSessions + SET commenterHex=$2 + WHERE session=$1; + ` + _, err := db.Exec(statement, session, commenterHex) + if err != nil { + logger.Errorf("error updating commenterHex in commenterSessions: %v", err) + return errorInternal + } + + return nil +} diff --git a/api/commenter_session_update_test.go b/api/commenter_session_update_test.go new file mode 100644 index 0000000..cdea0af --- /dev/null +++ b/api/commenter_session_update_test.go @@ -0,0 +1,25 @@ +package main + +import ( + "testing" +) + +func TestCommenterSessionUpdateBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + session, _ := commenterSessionNew() + + if err := commenterSessionUpdate(session, "temp-commenter-hex"); err != nil { + t.Errorf("unexpected error updating session to commenterHex: %v", err) + return + } +} + +func TestCommenterSessionUpdateEmpty(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if err := commenterSessionUpdate("", "temp-commenter-hex"); err == nil { + t.Errorf("expected error not found when updating with empty session") + return + } +} diff --git a/api/commenter_test.go b/api/commenter_test.go new file mode 100644 index 0000000..c4aa047 --- /dev/null +++ b/api/commenter_test.go @@ -0,0 +1,42 @@ +package main + +import ( + "testing" +) + +func TestCommenterIsProviderUserBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + commenterNew("test@example.com", "Test", "undefined", "https://example.com/photo.jpg", "google") + + exists, err := commenterIsProviderUser("google", "test@example.com") + if err != nil { + t.Errorf("unexpected error checking if commenter is a provider user: %v", err) + return + } + + if !exists { + t.Errorf("user expected to exist not found") + return + } + + exists, err = commenterIsProviderUser("google", "test2@example.com") + if err != nil { + t.Errorf("unexpected error checking if commenter is a provider user: %v", err) + return + } + + if exists { + t.Errorf("user expected to not exist not found") + return + } +} + +func TestCommenterIsProviderUserEmpty(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if _, err := commenterIsProviderUser("google", ""); err == nil { + t.Errorf("expected error not found when checking for user with empty email") + return + } +} diff --git a/api/config.go b/api/config.go new file mode 100644 index 0000000..5d434d4 --- /dev/null +++ b/api/config.go @@ -0,0 +1,40 @@ +package main + +import ( + "os" +) + +func parseConfig() error { + defaults := map[string]string{ + "POSTGRES": "postgres://postgres:postgres@0.0.0.0/commento?sslmode=disable", + + "PORT": "8080", + "ORIGIN": "", + + "CDN_PREFIX": "", + + "SMTP_USERNAME": "", + "SMTP_PASSWORD": "", + "SMTP_HOST": "", + "SMTP_FROM_ADDRESS": "", + + "OAUTH_GOOGLE_KEY": "", + "OAUTH_GOOGLE_SECRET": "", + } + + for key, value := range defaults { + if os.Getenv(key) == "" { + os.Setenv(key, value) + } + } + + // Mandatory config parameters + for _, env := range []string{"POSTGRES", "PORT", "ORIGIN"} { + if os.Getenv(env) == "" { + logger.Fatalf("missing %s environment variable", env) + return errorMissingConfig + } + } + + return nil +} diff --git a/api/config_test.go b/api/config_test.go new file mode 100644 index 0000000..66279cc --- /dev/null +++ b/api/config_test.go @@ -0,0 +1,33 @@ +package main + +import ( + "os" + "testing" +) + +func TestParseConfigBasics(t *testing.T) { + os.Setenv("ORIGIN", "https://commento.io") + + if err := parseConfig(); err != nil { + t.Errorf("unexpected error when parsing config: %v", err) + return + } + + // This test feels kinda stupid, but whatever. + if os.Getenv("PORT") != "8080" { + t.Errorf("expected PORT=8080, but PORT=%s instead", os.Getenv("PORT")) + return + } + + os.Setenv("PORT", "1886") + + if err := parseConfig(); err != nil { + t.Errorf("unexpected error when parsing config: %v", err) + return + } + + if os.Getenv("PORT") != "1886" { + t.Errorf("expected PORT=1886, but PORT=%s instead", os.Getenv("PORT")) + return + } +} diff --git a/api/database.go b/api/database.go new file mode 100644 index 0000000..4371b0e --- /dev/null +++ b/api/database.go @@ -0,0 +1,7 @@ +package main + +import ( + "database/sql" +) + +var db *sql.DB diff --git a/api/database_connect.go b/api/database_connect.go new file mode 100644 index 0000000..110763e --- /dev/null +++ b/api/database_connect.go @@ -0,0 +1,39 @@ +package main + +import ( + "database/sql" + _ "github.com/lib/pq" + "os" +) + +func connectDB() error { + con := os.Getenv("POSTGRES") + logger.Infof("opening connection to postgres: %s", con) + + var err error + db, err = sql.Open("postgres", con) + if err != nil { + logger.Errorf("cannot open connection to postgres: %v", err) + return err + } + + statement := ` + CREATE TABLE IF NOT EXISTS migrations ( + filename TEXT NOT NULL UNIQUE + ); + ` + _, err = db.Exec(statement) + if err != nil { + logger.Errorf("cannot create migrations table: %v", err) + return err + } + + // At most 1000 database connections will be left open in the idle state. This + // was found to be important when benchmarking with `wrk`: if this was unset, + // too many open idle connections were present, resulting in dropped requests + // due to the limit on the number of file handles. On benchmarking, around + // 100 was found to be pretty optimal. + db.SetMaxIdleConns(100) + + return nil +} diff --git a/api/database_migrations.go b/api/database_migrations.go new file mode 100644 index 0000000..3d385b9 --- /dev/null +++ b/api/database_migrations.go @@ -0,0 +1,80 @@ +package main + +import ( + "io/ioutil" + "os" + "strings" +) + +func performMigrations() error { + return performMigrationsFromDir("db") +} + +func performMigrationsFromDir(dir string) error { + files, err := ioutil.ReadDir(dir) + if err != nil { + logger.Errorf("cannot read directory for migrations: %v", err) + return err + } + + statement := ` + SELECT filename + FROM migrations; + ` + rows, err := db.Query(statement) + if err != nil { + logger.Errorf("cannot query migrations: %v", err) + return err + } + + defer rows.Close() + + filenames := make(map[string]bool) + for rows.Next() { + var filename string + if err = rows.Scan(&filename); err != nil { + logger.Errorf("cannot scan filename: %v", err) + return err + } + + filenames[filename] = true + } + + completed := 0 + for _, file := range files { + if strings.HasSuffix(file.Name(), ".sql") { + if !filenames[file.Name()] { + f := dir + string(os.PathSeparator) + file.Name() + contents, err := ioutil.ReadFile(f) + if err != nil { + logger.Errorf("cannot read file %s: %v", file.Name(), err) + return err + } + + if _, err = db.Exec(string(contents)); err != nil { + logger.Errorf("cannot execute the SQL in %s: %v", f, err) + return err + } + + statement = ` + INSERT INTO + migrations (filename) + VALUES ($1 ); + ` + _, err = db.Exec(statement, file.Name()) + if err != nil { + logger.Errorf("cannot insert filename into the migrations table: %v", err) + return err + } + + completed++ + } + } + } + + if completed > 0 { + logger.Infof("%d migrations found, %d new migrations completed (%d total)", len(filenames), completed, len(filenames)+completed) + } + + return nil +} diff --git a/api/domain.go b/api/domain.go new file mode 100644 index 0000000..ebec5e5 --- /dev/null +++ b/api/domain.go @@ -0,0 +1,18 @@ +package main + +import ( + "time" +) + +type domain struct { + Domain string `json:"domain"` + OwnerHex string `json:"ownerHex"` + Name string `json:"name"` + CreationDate time.Time `json:"creationDate"` + State string `json:"state"` + ImportedComments bool `json:"importedComments"` + AutoSpamFilter bool `json:"autoSpamFilter"` + RequireModeration bool `json:"requireModeration"` + RequireIdentification bool `json:"requireIdentification"` + Moderators []moderator `json:"moderators"` +} diff --git a/api/domain_delete.go b/api/domain_delete.go new file mode 100644 index 0000000..2689c99 --- /dev/null +++ b/api/domain_delete.go @@ -0,0 +1,52 @@ +package main + +import () + +func domainDelete(domain string) error { + if domain == "" { + return errorMissingField + } + + statement := ` + DELETE FROM + domains + WHERE domain = $1; + ` + _, err := db.Exec(statement, domain) + if err != nil { + return errorNoSuchDomain + } + + statement = ` + DELETE FROM votes + USING comments + WHERE comments.commentHex = votes.commentHex AND comments.domain = $1; + ` + _, err = db.Exec(statement, domain) + if err != nil { + logger.Errorf("cannot delete votes: %v", err) + return errorInternal + } + + statement = ` + DELETE FROM views + WHERE views.domain = $1; + ` + _, err = db.Exec(statement, domain) + if err != nil { + logger.Errorf("cannot delete views: %v", err) + return errorInternal + } + + statement = ` + DELETE FROM comments + WHERE comments.domain = $1; + ` + _, err = db.Exec(statement, domain) + if err != nil { + logger.Errorf(statement, domain) + return errorInternal + } + + return nil +} diff --git a/api/domain_delete_test.go b/api/domain_delete_test.go new file mode 100644 index 0000000..410776e --- /dev/null +++ b/api/domain_delete_test.go @@ -0,0 +1,38 @@ +package main + +import ( + "testing" +) + +func TestDomainDeleteBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + domainNew("temp-owner-hex", "Example", "example.com") + domainNew("temp-owner-hex", "Example", "example2.com") + + if err := domainDelete("example.com"); err != nil { + t.Errorf("unexpected error deleting domain: %v", err) + return + } + + d, _ := domainList("temp-owner-hex") + + if len(d) != 1 { + t.Errorf("expected number of domains to be 1 got %d", len(d)) + return + } + + if d[0].Domain != "example2.com" { + t.Errorf("expected first domain to be example2.com got %s", d[0].Domain) + return + } +} + +func TestDomainDeleteEmpty(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if err := domainDelete(""); err == nil { + t.Errorf("expected error not found when deleting with empty domain") + return + } +} diff --git a/api/domain_get.go b/api/domain_get.go new file mode 100644 index 0000000..b881ecc --- /dev/null +++ b/api/domain_get.go @@ -0,0 +1,29 @@ +package main + +import () + +func domainGet(dmn string) (domain, error) { + if dmn == "" { + return domain{}, errorMissingField + } + + statement := ` + SELECT domain, ownerHex, name, creationDate, state, importedComments, autoSpamFilter, requireModeration, requireIdentification + FROM domains + WHERE domain = $1; + ` + row := db.QueryRow(statement, dmn) + + var err error + d := domain{} + if err = row.Scan(&d.Domain, &d.OwnerHex, &d.Name, &d.CreationDate, &d.State, &d.ImportedComments, &d.AutoSpamFilter, &d.RequireModeration, &d.RequireIdentification); err != nil { + return d, errorNoSuchDomain + } + + d.Moderators, err = domainModeratorList(d.Domain) + if err != nil { + return domain{}, err + } + + return d, nil +} diff --git a/api/domain_get_test.go b/api/domain_get_test.go new file mode 100644 index 0000000..f90ed53 --- /dev/null +++ b/api/domain_get_test.go @@ -0,0 +1,40 @@ +package main + +import ( + "testing" +) + +func TestDomainGetBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + domainNew("temp-owner-hex", "Example", "example.com") + + d, err := domainGet("example.com") + if err != nil { + t.Errorf("unexpected error getting domain: %v", err) + return + } + + if d.Name != "Example" { + t.Errorf("expected name=Example got name=%s", d.Name) + return + } +} + +func TestDomainGetEmpty(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if _, err := domainGet(""); err == nil { + t.Errorf("expected error not found when getting with empty domain") + return + } +} + +func TestDomainGetDNE(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if _, err := domainGet("example.com"); err == nil { + t.Errorf("expected error not found when getting non-existant domain") + return + } +} diff --git a/api/domain_list.go b/api/domain_list.go new file mode 100644 index 0000000..8752a8c --- /dev/null +++ b/api/domain_list.go @@ -0,0 +1,67 @@ +package main + +import ( + "net/http" +) + +func domainList(ownerHex string) ([]domain, error) { + if ownerHex == "" { + return []domain{}, errorMissingField + } + + statement := ` + SELECT domain, ownerHex, name, creationDate, state, importedComments, autoSpamFilter, requireModeration, requireIdentification + FROM domains + WHERE ownerHex=$1; + ` + rows, err := db.Query(statement, ownerHex) + if err != nil { + logger.Errorf("cannot query domains: %v", err) + return nil, errorInternal + } + defer rows.Close() + + domains := []domain{} + for rows.Next() { + d := domain{} + if err = rows.Scan(&d.Domain, &d.OwnerHex, &d.Name, &d.CreationDate, &d.State, &d.ImportedComments, &d.AutoSpamFilter, &d.RequireModeration, &d.RequireIdentification); err != nil { + logger.Errorf("cannot Scan domain: %v", err) + return nil, errorInternal + } + + d.Moderators, err = domainModeratorList(d.Domain) + if err != nil { + return []domain{}, err + } + + domains = append(domains, d) + } + + return domains, rows.Err() +} + +func domainListHandler(w http.ResponseWriter, r *http.Request) { + type request struct { + Session *string `json:"session"` + } + + var x request + if err := unmarshalBody(r, &x); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + o, err := ownerGetBySession(*x.Session) + if err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + domains, err := domainList(o.OwnerHex) + if err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + writeBody(w, response{"success": true, "domains": domains}) +} diff --git a/api/domain_list_test.go b/api/domain_list_test.go new file mode 100644 index 0000000..d0e5175 --- /dev/null +++ b/api/domain_list_test.go @@ -0,0 +1,33 @@ +package main + +import ( + "testing" +) + +func TestDomainListBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + domainNew("temp-owner-hex", "Example", "example.com") + domainNew("temp-owner-hex", "Example", "example2.com") + + d, err := domainList("temp-owner-hex") + if err != nil { + t.Errorf("unexpected error listing domains: %v", err) + return + } + + if len(d) != 2 { + t.Errorf("expected number of domains to be 2 got %d", len(d)) + return + } + + if d[0].Domain != "example.com" { + t.Errorf("expected first domain to be example.com got %s", d[0].Domain) + return + } + + if d[1].Domain != "example2.com" { + t.Errorf("expected first domain to be example2.com got %s", d[1].Domain) + return + } +} diff --git a/api/domain_moderator.go b/api/domain_moderator.go new file mode 100644 index 0000000..265c931 --- /dev/null +++ b/api/domain_moderator.go @@ -0,0 +1,57 @@ +package main + +import ( + "time" +) + +type moderator struct { + Email string `json:"email"` + Domain string `json:"domain"` + AddDate time.Time `json:"addDate"` +} + +func domainModeratorList(domain string) ([]moderator, error) { + statement := ` + SELECT email, addDate + FROM moderators + WHERE domain=$1; + ` + rows, err := db.Query(statement, domain) + if err != nil { + logger.Errorf("cannot get moderators: %v", err) + return nil, errorInternal + } + defer rows.Close() + + moderators := []moderator{} + for rows.Next() { + m := moderator{} + if err = rows.Scan(&m.Email, &m.AddDate); err != nil { + logger.Errorf("cannot Scan moderator: %v", err) + return nil, errorInternal + } + + moderators = append(moderators, m) + } + + return moderators, nil +} + +func isDomainModerator(domain string, email string) (bool, error) { + statement := ` + SELECT EXISTS ( + SELECT 1 + FROM moderators + WHERE domain=$1 AND email=$2 + ); + ` + row := db.QueryRow(statement, domain, email) + + var exists bool + if err := row.Scan(&exists); err != nil { + logger.Errorf("cannot query if moderator: %v", err) + return false, errorInternal + } + + return exists, nil +} diff --git a/api/domain_moderator_delete.go b/api/domain_moderator_delete.go new file mode 100644 index 0000000..9589513 --- /dev/null +++ b/api/domain_moderator_delete.go @@ -0,0 +1,62 @@ +package main + +import ( + "net/http" +) + +func domainModeratorDelete(domain string, email string) error { + if domain == "" || email == "" { + return errorMissingConfig + } + + statement := ` + DELETE FROM moderators + WHERE domain=$1 AND email=$2; + ` + _, err := db.Exec(statement, domain, email) + if err != nil { + logger.Errorf("cannot delete moderator: %v", err) + return errorInternal + } + + return nil +} + +func domainModeratorDeleteHandler(w http.ResponseWriter, r *http.Request) { + type request struct { + Session *string `json:"session"` + Domain *string `json:"domain"` + Email *string `json:"email"` + } + + var x request + if err := unmarshalBody(r, &x); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + o, err := ownerGetBySession(*x.Session) + if err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + domain := stripDomain(*x.Domain) + authorised, err := domainOwnershipVerify(domain, o.Email) + if err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + if !authorised { + writeBody(w, response{"success": false, "message": errorNotAuthorised.Error()}) + return + } + + if err = domainModeratorDelete(domain, *x.Email); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + writeBody(w, response{"success": true}) +} diff --git a/api/domain_moderator_delete_test.go b/api/domain_moderator_delete_test.go new file mode 100644 index 0000000..d49471d --- /dev/null +++ b/api/domain_moderator_delete_test.go @@ -0,0 +1,45 @@ +package main + +import ( + "testing" +) + +func TestDomainModeratorDeleteBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + domainModeratorNew("example.com", "test@example.com") + domainModeratorNew("example.com", "test2@example.com") + + if err := domainModeratorDelete("example.com", "test@example.com"); err != nil { + t.Errorf("unexpected error creating new domain moderator: %v", err) + return + } + + isMod, _ := isDomainModerator("example.com", "test@example.com") + if isMod { + t.Errorf("email %s still moderator after deletion", "test@example.com") + return + } + + isMod, _ = isDomainModerator("example.com", "test2@example.com") + if !isMod { + t.Errorf("email %s no longer moderator after deleting a different email", "test@example.com") + return + } +} + +func TestDomainModeratorDeleteEmpty(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + domainModeratorNew("example.com", "test@example.com") + + if err := domainModeratorDelete("example.com", ""); err == nil { + t.Errorf("expected error not found when passing empty email") + return + } + + if err := domainModeratorDelete("", ""); err == nil { + t.Errorf("expected error not found when passing empty everything") + return + } +} diff --git a/api/domain_moderator_new.go b/api/domain_moderator_new.go new file mode 100644 index 0000000..2ce2d6b --- /dev/null +++ b/api/domain_moderator_new.go @@ -0,0 +1,64 @@ +package main + +import ( + "net/http" + "time" +) + +func domainModeratorNew(domain string, email string) error { + if domain == "" || email == "" { + return errorMissingField + } + + statement := ` + INSERT INTO + moderators (domain, email, addDate) + VALUES ($1, $2, $3 ); + ` + _, err := db.Exec(statement, domain, email, time.Now().UTC()) + if err != nil { + logger.Errorf("cannot insert new moderator: %v", err) + return errorInternal + } + + return nil +} + +func domainModeratorNewHandler(w http.ResponseWriter, r *http.Request) { + type request struct { + Session *string `json:"session"` + Domain *string `json:"domain"` + Email *string `json:"email"` + } + + var x request + if err := unmarshalBody(r, &x); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + o, err := ownerGetBySession(*x.Session) + if err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + domain := stripDomain(*x.Domain) + isOwner, err := domainOwnershipVerify(o.OwnerHex, domain) + if err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + if !isOwner { + writeBody(w, response{"success": false, "message": errorNotAuthorised.Error()}) + return + } + + if err = domainModeratorNew(domain, *x.Email); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + writeBody(w, response{"success": true}) +} diff --git a/api/domain_moderator_new_test.go b/api/domain_moderator_new_test.go new file mode 100644 index 0000000..1654f8d --- /dev/null +++ b/api/domain_moderator_new_test.go @@ -0,0 +1,28 @@ +package main + +import ( + "testing" +) + +func TestDomainModeratorNewBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if err := domainModeratorNew("example.com", "test@example.com"); err != nil { + t.Errorf("unexpected error creating new domain moderator: %v", err) + return + } +} + +func TestDomainModeratorNewEmpty(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if err := domainModeratorNew("example.com", ""); err == nil { + t.Errorf("expected error not found when creating new moderator with empty email") + return + } + + if err := domainModeratorNew("", "test@example.com"); err == nil { + t.Errorf("expected error not found when creating new moderator with empty domain") + return + } +} diff --git a/api/domain_moderator_test.go b/api/domain_moderator_test.go new file mode 100644 index 0000000..69a61e8 --- /dev/null +++ b/api/domain_moderator_test.go @@ -0,0 +1,61 @@ +package main + +import ( + "testing" +) + +func TestDomainModeratorListBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + domainModeratorNew("example.com", "test@example.com") + domainModeratorNew("example.com", "test2@example.com") + + mods, err := domainModeratorList("example.com") + if err != nil { + t.Errorf("unexpected error listing domain moderators: %v", err) + return + } + + if len(mods) != 2 { + t.Errorf("expected number of domain moderators to be 2 got %d", len(mods)) + return + } + + if mods[0].Email != "test@example.com" { + t.Errorf("expected first domain to be test@example.com got %s", mods[0].Email) + return + } + + if mods[1].Email != "test2@example.com" { + t.Errorf("expected first domain to be test2@example.com got %s", mods[0].Email) + return + } +} + +func TestIsDomainModeratorBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + domainModeratorNew("example.com", "test@example.com") + + isMod, err := isDomainModerator("example.com", "test@example.com") + if err != nil { + t.Errorf("unexpected error checking if email is a moderator: %v", err) + return + } + + if !isMod { + t.Errorf("expected test@example.com to be a moderator got isMod=false") + return + } + + isMod, err = isDomainModerator("example.com", "test2@example.com") + if err != nil { + t.Errorf("unexpected error checking if email is a moderator: %v", err) + return + } + + if isMod { + t.Errorf("expected test2@example.com to not be a moderator got isMod=true") + return + } +} diff --git a/api/domain_new.go b/api/domain_new.go new file mode 100644 index 0000000..d7c8a71 --- /dev/null +++ b/api/domain_new.go @@ -0,0 +1,59 @@ +package main + +import ( + "net/http" + "time" +) + +func domainNew(ownerHex string, name string, domain string) error { + if ownerHex == "" || name == "" || domain == "" { + return errorMissingField + } + + statement := ` + INSERT INTO + domains (ownerHex, name, domain, creationDate) + VALUES ($1, $2, $3, $4 ); + ` + _, err := db.Exec(statement, ownerHex, name, domain, time.Now().UTC()) + if err != nil { + // TODO: Make sure this is really the error. + return errorDomainAlreadyExists + } + + return nil +} + +func domainNewHandler(w http.ResponseWriter, r *http.Request) { + type request struct { + Session *string `json:"session"` + Name *string `json:"name"` + Domain *string `json:"domain"` + } + + var x request + if err := unmarshalBody(r, &x); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + o, err := ownerGetBySession(*x.Session) + if err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + domain := stripDomain(*x.Domain) + + if err = domainNew(o.Email, *x.Name, domain); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + if err = domainModeratorNew(*x.Domain, o.Email); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + writeBody(w, response{"success": true, "domain": domain}) +} diff --git a/api/domain_new_test.go b/api/domain_new_test.go new file mode 100644 index 0000000..8df2258 --- /dev/null +++ b/api/domain_new_test.go @@ -0,0 +1,42 @@ +package main + +import ( + "testing" +) + +func TestDomainNewBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if err := domainNew("temp-owner-hex", "Example", "example.com"); err != nil { + t.Errorf("unexpected error creating domain: %v", err) + return + } +} + +func TestDomainNewClash(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if err := domainNew("temp-owner-hex", "Example", "example.com"); err != nil { + t.Errorf("unexpected error creating domain: %v", err) + return + } + + if err := domainNew("temp-owner-hex", "Example 2", "example.com"); err == nil { + t.Errorf("expected error not found when creating with clashing domain") + return + } +} + +func TestDomainNewEmpty(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if err := domainNew("temp-owner-hex", "Example", ""); err == nil { + t.Errorf("expected error not found when creating with emtpy domain") + return + } + + if err := domainNew("", "", ""); err == nil { + t.Errorf("expected error not found when creating with emtpy everything") + return + } +} diff --git a/api/domain_ownership_verify.go b/api/domain_ownership_verify.go new file mode 100644 index 0000000..fc6380a --- /dev/null +++ b/api/domain_ownership_verify.go @@ -0,0 +1,26 @@ +package main + +import () + +func domainOwnershipVerify(ownerHex string, domain string) (bool, error) { + if ownerHex == "" || domain == "" { + return false, errorMissingField + } + + statement := ` + SELECT EXISTS ( + SELECT 1 + FROM domains + WHERE ownerHex=$1 AND domain=$2 + ); + ` + row := db.QueryRow(statement, ownerHex, domain) + + var exists bool + if err := row.Scan(&exists); err != nil { + logger.Errorf("cannot query if domain owner: %v", err) + return false, errorInternal + } + + return exists, nil +} diff --git a/api/domain_ownership_verify_test.go b/api/domain_ownership_verify_test.go new file mode 100644 index 0000000..00614b9 --- /dev/null +++ b/api/domain_ownership_verify_test.go @@ -0,0 +1,39 @@ +package main + +import ( + "testing" +) + +func TestDomainVerifyOwnershipBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + ownerHex, _ := ownerNew("test@example.com", "Test", "hunter2") + ownerLogin("test@example.com", "hunter2") + + domainNew(ownerHex, "Example", "example.com") + + isOwner, err := domainOwnershipVerify(ownerHex, "example.com") + if err != nil { + t.Errorf("error checking ownership: %v", err) + return + } + + if !isOwner { + t.Errorf("expected isOwner=true got isOwner=false") + return + } + + otherOwnerHex, _ := ownerNew("test2@example.com", "Test2", "hunter2") + ownerLogin("test2@example.com", "hunter2") + + isOwner, err = domainOwnershipVerify(otherOwnerHex, "example.com") + if err != nil { + t.Errorf("error checking ownership: %v", err) + return + } + + if isOwner { + t.Errorf("expected isOwner=false got isOwner=true") + return + } +} diff --git a/api/domain_update.go b/api/domain_update.go new file mode 100644 index 0000000..0517478 --- /dev/null +++ b/api/domain_update.go @@ -0,0 +1,59 @@ +package main + +import ( + "net/http" +) + +func domainUpdate(d domain) error { + statement := ` + UPDATE domains + SET name=$2, state=$3, autoSpamFilter=$4, requireModeration=$5, requireIdentification=$6 + WHERE domain=$1; + ` + + _, err := db.Exec(statement, d.Domain, d.Name, d.State, d.AutoSpamFilter, d.RequireModeration, d.RequireIdentification) + if err != nil { + logger.Errorf("cannot update non-moderators: %v", err) + return errorInternal + } + + return nil +} + +func domainUpdateHandler(w http.ResponseWriter, r *http.Request) { + type request struct { + Session *string `json:"session"` + D *domain `json:"domain"` + } + + var x request + if err := unmarshalBody(r, &x); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + o, err := ownerGetBySession(*x.Session) + if err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + domain := stripDomain((*x.D).Domain) + isOwner, err := domainOwnershipVerify(o.OwnerHex, domain) + if err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + if !isOwner { + writeBody(w, response{"success": false, "message": errorNotAuthorised.Error()}) + return + } + + if err = domainUpdate(*x.D); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + writeBody(w, response{"success": true}) +} diff --git a/api/domain_update_test.go b/api/domain_update_test.go new file mode 100644 index 0000000..5a8d4e8 --- /dev/null +++ b/api/domain_update_test.go @@ -0,0 +1,27 @@ +package main + +import ( + "testing" +) + +func TestDomainUpdateBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + domainNew("temp-owner-hex", "Example", "example.com") + + d, _ := domainList("temp-owner-hex") + + d[0].Name = "Example2" + + if err := domainUpdate(d[0]); err != nil { + t.Errorf("unexpected error updating domain: %v", err) + return + } + + d, _ = domainList("temp-owner-hex") + + if d[0].Name != "Example2" { + t.Errorf("expected name=Example2 got name=%s", d[0].Name) + return + } +} diff --git a/api/domain_view_record.go b/api/domain_view_record.go new file mode 100644 index 0000000..b3a5572 --- /dev/null +++ b/api/domain_view_record.go @@ -0,0 +1,17 @@ +package main + +import ( + "time" +) + +func domainViewRecord(domain string, commenterHex string) { + statement := ` + INSERT INTO + views (domain, commenterHex, viewDate) + VALUES ($1, $2, $3 ); + ` + _, err := db.Exec(statement, domain, commenterHex, time.Now().UTC()) + if err != nil { + logger.Warningf("cannot insert views: %v", err) + } +} diff --git a/api/errors.go b/api/errors.go new file mode 100644 index 0000000..d9a8d32 --- /dev/null +++ b/api/errors.go @@ -0,0 +1,44 @@ +package main + +import ( + "errors" +) + +var errorMalformedTemplate = errors.New("A template is malformed.") +var errorMissingConfig = errors.New("Missing config environment variable.") +var errorCannotSendEmail = errors.New("Email dispatch failed. Please contact support to resolve this issue.") +var errorInternal = errors.New("An internal error has occurred. If you see this repeatedly, please contact support.") +var errorInvalidJSONBody = errors.New("Invalid JSON request. If you think this shouldn't happen, please contact support.") +var errorMissingField = errors.New("One or more field(s) empty.") +var errorEmailExists = errors.New("That email address is already registered. Sign in instead?") +var errorInvalidEmailPassword = errors.New("Invalid email/password combination.") +var errorUnconfirmedEmail = errors.New("Your email address is still unconfirmed. Please confirm your email address before proceeding.") +var errorNoSuchConfirmationToken = errors.New("This email confirmation link has expired.") +var errorNoSuchResetToken = errors.New("This password reset link has expired.") +var errorNotAuthorised = errors.New("You're not authorised to access that.") +var errorEmailAlreadyExists = errors.New("That email address has already been registered.") +var errorNoSuchSession = errors.New("No such session/state.") +var errorAlreadyUpvoted = errors.New("You have already upvoted that comment.") +var errorNoSuchDomain = errors.New("This domain is not registered with Commento.") +var errorNoSuchComment = errors.New("No such comment.") +var errorNeedPro = errors.New("You need to have a pro/business account to do that.") +var errorInvalidState = errors.New("Invalid state value.") +var errorInvalidTrial = errors.New("Invalid trial value.") +var errorDomainFrozen = errors.New("Cannot add a new comment because that domain is frozen.") +var errorDomainAlreadyExists = errors.New("That domain has already been registered. Please contact support if you are the true owner.") +var errorUnauthorisedVote = errors.New("You need to be authenticated to vote.") +var errorNoSuchEmail = errors.New("No such email.") +var errorInvalidEmail = errors.New("You do not have an email registered with that account.") +var errorNoTrialChange = errors.New("You cannot change to a trial plan.") +var errorInvalidPlan = errors.New("Invalid plan value.") +var errorNoSource = errors.New("You have no payment source on record to change your plan.") +var errorCannotDowngrage = errors.New("Cannot downgrade plan features.") +var errorForbiddenEdit = errors.New("You cannot edit someone else's comment.") +var errorNotInvited = errors.New("Commento is currently in private beta and invite-only for now.") +var errorMissingSmtpAddress = errors.New("Missing SMTP_FROM_ADDRESS") +var errorSmtpNotConfigured = errors.New("SMTP is not configured.") +var errorOauthMisconfigured = errors.New("OAuth is misconfigured.") +var errorUnassociatedSession = errors.New("No user associated with that session.") +var errorSessionAlreadyInUse = errors.New("Session is already in use.") +var errorCannotReadResponse = errors.New("Cannot read response.") +var errorNotModerator = errors.New("You need to be a moderator to do that.") diff --git a/api/main.go b/api/main.go new file mode 100644 index 0000000..a395a2c --- /dev/null +++ b/api/main.go @@ -0,0 +1,13 @@ +package main + +func main() { + exitIfError(createLogger()) + exitIfError(parseConfig()) + exitIfError(connectDB()) + exitIfError(performMigrations()) + exitIfError(smtpConfigure()) + exitIfError(oauthConfigure()) + exitIfError(createMarkdownRenderer()) + + exitIfError(serveRoutes()) +} diff --git a/api/markdown.go b/api/markdown.go new file mode 100644 index 0000000..5c96693 --- /dev/null +++ b/api/markdown.go @@ -0,0 +1,28 @@ +package main + +import ( + "github.com/microcosm-cc/bluemonday" + "gopkg.in/russross/blackfriday.v1" +) + +var policy *bluemonday.Policy +var renderer blackfriday.Renderer +var extensions int + +func createMarkdownRenderer() error { + policy = bluemonday.UGCPolicy() + + extensions = 0 + extensions |= blackfriday.EXTENSION_AUTOLINK + extensions |= blackfriday.EXTENSION_STRIKETHROUGH + + htmlFlags := 0 + htmlFlags |= blackfriday.HTML_SKIP_HTML + htmlFlags |= blackfriday.HTML_SKIP_IMAGES + htmlFlags |= blackfriday.HTML_SAFELINK + htmlFlags |= blackfriday.HTML_HREF_TARGET_BLANK + + renderer = blackfriday.HtmlRenderer(htmlFlags, "", "") + + return nil +} diff --git a/api/markdown_html.go b/api/markdown_html.go new file mode 100644 index 0000000..e799d18 --- /dev/null +++ b/api/markdown_html.go @@ -0,0 +1,10 @@ +package main + +import ( + "gopkg.in/russross/blackfriday.v1" +) + +func markdownToHtml(markdown string) string { + unsafe := blackfriday.Markdown([]byte(markdown), renderer, extensions) + return string(policy.SanitizeBytes(unsafe)) +} diff --git a/api/markdown_html_test.go b/api/markdown_html_test.go new file mode 100644 index 0000000..6e03cf0 --- /dev/null +++ b/api/markdown_html_test.go @@ -0,0 +1,39 @@ +package main + +import ( + "strings" + "testing" +) + +func TestMarkdownToHtmlBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + // basic markdown and expected html tests + tests := map[string]string{ + "Foo": "

Foo

", + + "Foo\n\nBar": "

Foo

\n\n

Bar

", + + "XSS: Foo": "

XSS: Foo

", + + "Regular [Link](http://example.com)": "

Regular Link

", + + "XSS [Link](data:text/html;base64,PHNjcmlwdD5hbGVydCgxKTwvc2NyaXB0Pgo=)": "

XSS Link

", + + "![Images disallowed](http://example.com/image.jpg)": "

", + + "**bold** *italics*": "

bold italics

", + + "http://example.com/autolink": "

http://example.com/autolink

", + + "not bold": "

not bold

", + } + + for in, out := range tests { + html := strings.TrimSpace(markdownToHtml(in)) + if html != out { + t.Errorf("for in=[%s] expected out=[%s] got out=[%s]", in, out, html) + return + } + } +} diff --git a/api/oauth.go b/api/oauth.go new file mode 100644 index 0000000..c9c7fc3 --- /dev/null +++ b/api/oauth.go @@ -0,0 +1,11 @@ +package main + +import () + +func oauthConfigure() error { + if err := googleOauthConfigure(); err != nil { + return err + } + + return nil +} diff --git a/api/oauth_google.go b/api/oauth_google.go new file mode 100644 index 0000000..3a7a519 --- /dev/null +++ b/api/oauth_google.go @@ -0,0 +1,41 @@ +package main + +import ( + "golang.org/x/oauth2" + "golang.org/x/oauth2/google" + "os" +) + +var googleConfig *oauth2.Config + +func googleOauthConfigure() error { + googleConfig = nil + if os.Getenv("GOOGLE_KEY") == "" && os.Getenv("GOOGLE_SECRET") == "" { + return nil + } + + if os.Getenv("GOOGLE_KEY") == "" { + logger.Errorf("GOOGLE_KEY not configured, but GOOGLE_SECRET is set") + return errorOauthMisconfigured + } + + if os.Getenv("GOOGLE_SECRET") == "" { + logger.Errorf("GOOGLE_SECRET not configured, but GOOGLE_KEY is set") + return errorOauthMisconfigured + } + + logger.Infof("loading Google OAuth config") + + googleConfig = &oauth2.Config{ + RedirectURL: os.Getenv("BACKEND_WEB") + "/oauth/google/callback", + ClientID: os.Getenv("GOOGLE_KEY"), + ClientSecret: os.Getenv("GOOGLE_SECRET"), + Scopes: []string{ + "https://www.googleapis.com/auth/userinfo.profile", + "https://www.googleapis.com/auth/userinfo.email", + }, + Endpoint: google.Endpoint, + } + + return nil +} diff --git a/api/oauth_google_callback.go b/api/oauth_google_callback.go new file mode 100644 index 0000000..751a39a --- /dev/null +++ b/api/oauth_google_callback.go @@ -0,0 +1,85 @@ +package main + +import ( + "encoding/json" + "fmt" + "golang.org/x/oauth2" + "io/ioutil" + "net/http" +) + +func googleCallbackHandler(w http.ResponseWriter, r *http.Request) { + session := r.FormValue("state") + code := r.FormValue("code") + + cs, err := commenterSessionGet(session) + if err != nil { + fmt.Fprintf(w, "Error: %s\n", err.Error()) + return + } + + if cs.Session != "none" { + fmt.Fprintf(w, "Error: %v", errorSessionAlreadyInUse.Error()) + return + } + + token, err := googleConfig.Exchange(oauth2.NoContext, code) + if err != nil { + fmt.Fprintf(w, "Error: %s", err.Error()) + return + } + + resp, err := http.Get("https://www.googleapis.com/oauth2/v2/userinfo?access_token=" + token.AccessToken) + defer resp.Body.Close() + + contents, err := ioutil.ReadAll(resp.Body) + if err != nil { + fmt.Fprintf(w, "Error: %s", errorCannotReadResponse.Error()) + return + } + + user := make(map[string]interface{}) + if err := json.Unmarshal(contents, &user); err != nil { + fmt.Fprintf(w, "Error: %s", errorInternal.Error()) + return + } + + exists, err := commenterIsProviderUser("google", user["email"].(string)) + if err != nil { + fmt.Fprintf(w, "Error: %s", err.Error()) + return + } + + var commenterHex string + + // TODO: in case of returning users, update the information we have on record? + if !exists { + var email string + if _, ok := user["email"]; ok { + email = user["email"].(string) + } else { + fmt.Fprintf(w, "error: %s", errorInvalidEmail.Error()) + return + } + + var link string + if val, ok := user["link"]; ok { + link = val.(string) + } else { + link = "undefined" + } + + commenterHex, err = commenterNew(email, user["name"].(string), link, user["picture"].(string), "google") + if err != nil { + fmt.Fprintf(w, "Error: %s", err.Error()) + return + } + } + + if err := commenterSessionUpdate(session, commenterHex); err != nil { + fmt.Fprintf(w, "Error: %s", err.Error()) + return + } + + fmt.Fprintf(w, "") +} diff --git a/api/oauth_google_redirect.go b/api/oauth_google_redirect.go new file mode 100644 index 0000000..9767324 --- /dev/null +++ b/api/oauth_google_redirect.go @@ -0,0 +1,24 @@ +package main + +import ( + "fmt" + "net/http" +) + +func googleRedirectHandler(w http.ResponseWriter, r *http.Request) { + session := r.FormValue("session") + + c, err := commenterGetBySession(session) + if err != nil { + fmt.Fprintf(w, "error: %s\n", err.Error()) + return + } + + if c.CommenterHex != "none" { + fmt.Fprintf(w, "error: that session is already in use\n") + return + } + + url := googleConfig.AuthCodeURL(session) + http.Redirect(w, r, url, http.StatusFound) +} diff --git a/api/oauth_google_test.go b/api/oauth_google_test.go new file mode 100644 index 0000000..2cf2852 --- /dev/null +++ b/api/oauth_google_test.go @@ -0,0 +1,59 @@ +package main + +import ( + "os" + "testing" +) + +func resetGoogleVars() { + for _, env := range []string{"GOOGLE_KEY", "GOOGLE_SECRET"} { + os.Setenv(env, "") + } +} + +func TestGoogleOauthConfigureBasics(t *testing.T) { + resetGoogleVars() + + os.Setenv("GOOGLE_KEY", "google-key") + os.Setenv("GOOGLE_SECRET", "google-secret") + + if err := googleOauthConfigure(); err != nil { + t.Errorf("unexpected error configuring google oauth: %v", err) + return + } + + if googleConfig == nil { + t.Errorf("expected googleConfig!=nil got googleConfig=nil") + return + } +} + +func TestGoogleOauthConfigureEmpty(t *testing.T) { + resetGoogleVars() + + os.Setenv("GOOGLE_KEY", "google-key") + + if err := googleOauthConfigure(); err == nil { + t.Errorf("expected error not found when configuring google oauth with empty GOOGLE_SECRET") + return + } + + if googleConfig != nil { + t.Errorf("expected googleConfig=nil got googleConfig=%v", googleConfig) + return + } +} + +func TestGoogleOauthConfigureEmpty2(t *testing.T) { + resetGoogleVars() + + if err := googleOauthConfigure(); err != nil { + t.Errorf("unexpected error configuring google oauth with empty everything: should be disabled") + return + } + + if googleConfig != nil { + t.Errorf("expected googleConfig=nil got googleConfig=%v", googleConfig) + return + } +} diff --git a/api/owner.go b/api/owner.go new file mode 100644 index 0000000..b892f3e --- /dev/null +++ b/api/owner.go @@ -0,0 +1,13 @@ +package main + +import ( + "time" +) + +type owner struct { + OwnerHex string `json:"ownerHex"` + Email string `json:"email"` + Name string `json:"name"` + ConfirmedEmail bool `json:"confirmedEmail"` + JoinDate time.Time `json:"joinDate"` +} diff --git a/api/owner_confirm_hex.go b/api/owner_confirm_hex.go new file mode 100644 index 0000000..564a8bc --- /dev/null +++ b/api/owner_confirm_hex.go @@ -0,0 +1,61 @@ +package main + +import ( + "fmt" + "net/http" + "os" +) + +func ownerConfirmHex(confirmHex string) error { + if confirmHex == "" { + return errorMissingField + } + + statement := ` + UPDATE owners + SET confirmedEmail=true + WHERE ownerHex IN ( + SELECT ownerHex FROM ownerConfirmHexes + WHERE confirmHex=$1 + ); + ` + res, err := db.Exec(statement, confirmHex) + if err != nil { + logger.Errorf("cannot mark user's confirmedEmail as true: %v\n", err) + return errorInternal + } + + count, err := res.RowsAffected() + if err != nil { + logger.Errorf("cannot count rows affected: %v\n", err) + return errorInternal + } + + if count == 0 { + return errorNoSuchConfirmationToken + } + + statement = ` + DELETE FROM ownerConfirmHexes + WHERE confirmHex=$1; + ` + _, err = db.Exec(statement, confirmHex) + if err != nil { + logger.Warningf("cannot remove confirmation token: %v\n", err) + // Don't return an error because this is not critical. + } + + return nil +} + +func ownerConfirmHexHandler(w http.ResponseWriter, r *http.Request) { + if confirmHex := r.FormValue("confirmHex"); confirmHex != "" { + if err := ownerConfirmHex(confirmHex); err == nil { + http.Redirect(w, r, fmt.Sprintf("%s/login?confirmed=true", os.Getenv("FRONTEND")), http.StatusTemporaryRedirect) + return + } + } + + // TODO: include error message in the URL + http.Redirect(w, r, fmt.Sprintf("%s/login?confirmed=false", os.Getenv("FRONTEND")), http.StatusTemporaryRedirect) +} diff --git a/api/owner_confirm_hex_test.go b/api/owner_confirm_hex_test.go new file mode 100644 index 0000000..ed926e9 --- /dev/null +++ b/api/owner_confirm_hex_test.go @@ -0,0 +1,58 @@ +package main + +import ( + "testing" + "time" +) + +func TestOwnerConfirmHexBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + ownerHex, _ := ownerNew("test@example.com", "Test", "hunter2") + + statement := ` + UPDATE owners + SET confirmedEmail=false; + ` + _, err := db.Exec(statement) + if err != nil { + t.Errorf("unexpected error when setting confirmedEmail=false: %v", err) + return + } + + confirmHex, _ := randomHex(32) + + statement = ` + INSERT INTO + ownerConfirmHexes (confirmHex, ownerHex, sendDate) + VALUES ($1, $2, $3 ); + ` + _, err = db.Exec(statement, confirmHex, ownerHex, time.Now().UTC()) + if err != nil { + t.Errorf("unexpected error creating inserting confirmHex: %v\n", err) + return + } + + if err = ownerConfirmHex(confirmHex); err != nil { + t.Errorf("unexpected error confirming hex: %v", err) + return + } + + statement = ` + SELECT confirmedEmail + FROM owners + WHERE ownerHex=$1; + ` + row := db.QueryRow(statement, ownerHex) + + var confirmedHex bool + if err = row.Scan(&confirmedHex); err != nil { + t.Errorf("unexpected error scanning confirmedEmail: %v", err) + return + } + + if !confirmedHex { + t.Errorf("confirmedHex expected to be true after confirmation; found to be false") + return + } +} diff --git a/api/owner_get.go b/api/owner_get.go new file mode 100644 index 0000000..607a621 --- /dev/null +++ b/api/owner_get.go @@ -0,0 +1,48 @@ +package main + +import () + +func ownerGetByEmail(email string) (owner, error) { + if email == "" { + return owner{}, errorMissingField + } + + statement := ` + SELECT ownerHex, email, name, confirmedEmail, joinDate + FROM owners + WHERE email=$1; + ` + row := db.QueryRow(statement, email) + + var o owner + if err := row.Scan(&o.OwnerHex, &o.Email, &o.Name, &o.ConfirmedEmail, &o.JoinDate); err != nil { + // TODO: Make sure this is actually no such email. + return owner{}, errorNoSuchEmail + } + + return o, nil +} + +func ownerGetBySession(session string) (owner, error) { + if session == "" { + return owner{}, errorMissingField + } + + statement := ` + SELECT ownerHex, email, name, confirmedEmail, joinDate + FROM owners + WHERE email IN ( + SELECT email FROM ownerSessions + WHERE session=$1 + ); + ` + row := db.QueryRow(statement, session) + + var o owner + if err := row.Scan(&o.OwnerHex, &o.Email, &o.Name, &o.ConfirmedEmail, &o.JoinDate); err != nil { + logger.Errorf("cannot scan owner: %v\n", err) + return owner{}, errorInternal + } + + return o, nil +} diff --git a/api/owner_get_test.go b/api/owner_get_test.go new file mode 100644 index 0000000..1972e3c --- /dev/null +++ b/api/owner_get_test.go @@ -0,0 +1,59 @@ +package main + +import ( + "testing" +) + +func TestOwnerGetByEmailBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + ownerHex, _ := ownerNew("test@example.com", "Test", "hunter2") + + o, err := ownerGetByEmail("test@example.com") + if err != nil { + t.Errorf("unexpected error on ownerGetByEmail: %v", err) + return + } + + if o.OwnerHex != ownerHex { + t.Errorf("expected ownerHex=%s got ownerHex=%s", ownerHex, o.OwnerHex) + return + } +} + +func TestOwnerGetByEmailDNE(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if _, err := ownerGetByEmail("invalid@example.com"); err == nil { + t.Errorf("expected error not found on ownerGetByEmail before creating an account") + return + } +} + +func TestOwnerGetBySessionBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + ownerHex, _ := ownerNew("test@example.com", "Test", "hunter2") + + session, _ := ownerLogin("test@example.com", "hunter2") + + o, err := ownerGetBySession(session) + if err != nil { + t.Errorf("unexpected error on ownerGetBySession: %v", err) + return + } + + if o.OwnerHex != ownerHex { + t.Errorf("expected ownerHex=%s got ownerHex=%s", ownerHex, o.OwnerHex) + return + } +} + +func TestOwnerGetBySessionDNE(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if _, err := ownerGetBySession("does-not-exist"); err == nil { + t.Errorf("expected error not found on ownerGetBySession before creating an account") + return + } +} diff --git a/api/owner_login.go b/api/owner_login.go new file mode 100644 index 0000000..218c1ae --- /dev/null +++ b/api/owner_login.go @@ -0,0 +1,76 @@ +package main + +import ( + "golang.org/x/crypto/bcrypt" + "net/http" + "time" +) + +func ownerLogin(email string, password string) (string, error) { + if email == "" || password == "" { + return "", errorMissingField + } + + statement := ` + SELECT ownerHex, confirmedEmail, passwordHash + FROM owners + WHERE email=$1; + ` + row := db.QueryRow(statement, email) + + var ownerHex string + var confirmedEmail bool + var passwordHash string + if err := row.Scan(&ownerHex, &confirmedEmail, &passwordHash); err != nil { + return "", errorInvalidEmailPassword + } + + if !confirmedEmail { + return "", errorUnconfirmedEmail + } + + if err := bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(password)); err != nil { + // TODO: is this the only possible error? + return "", errorInvalidEmailPassword + } + + session, err := randomHex(32) + if err != nil { + logger.Errorf("cannot create session hex: %v", err) + return "", errorInternal + } + + statement = ` + INSERT INTO + ownerSessions (session, ownerHex, loginDate) + VALUES ($1, $2, $3 ); + ` + _, err = db.Exec(statement, session, ownerHex, time.Now().UTC()) + if err != nil { + logger.Errorf("cannot insert session token: %v\n", err) + return "", errorInternal + } + + return session, nil +} + +func ownerLoginHandler(w http.ResponseWriter, r *http.Request) { + type request struct { + Email *string `json:"email"` + Password *string `json:"password"` + } + + var x request + if err := unmarshalBody(r, &x); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + session, err := ownerLogin(*x.Email, *x.Password) + if err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + writeBody(w, response{"success": true, "session": session}) +} diff --git a/api/owner_login_test.go b/api/owner_login_test.go new file mode 100644 index 0000000..282c52e --- /dev/null +++ b/api/owner_login_test.go @@ -0,0 +1,47 @@ +package main + +import ( + "testing" +) + +func TestOwnerLoginBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if _, err := ownerLogin("test@example.com", "hunter2"); err == nil { + t.Errorf("expected error not found when logging in without creating an account") + return + } + + ownerNew("test@example.com", "Test", "hunter2") + + if _, err := ownerLogin("test@example.com", "hunter2"); err != nil { + t.Errorf("unexpected error when logging in: %v", err) + return + } + + if _, err := ownerLogin("test@example.com", "h******"); err == nil { + t.Errorf("expected error not found when given wrong password") + return + } + + if session, err := ownerLogin("test@example.com", "hunter2"); session == "" { + t.Errorf("empty session on successful login: %v", err) + return + } +} + +func TestOwnerLoginEmpty(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if _, err := ownerLogin("test@example.com", ""); err == nil { + t.Errorf("expected error not found when passing empty password") + return + } + + ownerNew("test@example.com", "Test", "hunter2") + + if _, err := ownerLogin("test@example.com", ""); err == nil { + t.Errorf("expected error not found when passing empty password") + return + } +} diff --git a/api/owner_new.go b/api/owner_new.go new file mode 100644 index 0000000..355df3f --- /dev/null +++ b/api/owner_new.go @@ -0,0 +1,83 @@ +package main + +import ( + "golang.org/x/crypto/bcrypt" + "net/http" + "time" +) + +func ownerNew(email string, name string, password string) (string, error) { + if email == "" || name == "" || password == "" { + return "", errorMissingField + } + + ownerHex, err := randomHex(32) + if err != nil { + logger.Errorf("cannot generate ownerHex: %v", err) + return "", errorInternal + } + + passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + logger.Errorf("cannot generate hash from password: %v\n", err) + return "", errorInternal + } + + statement := ` + INSERT INTO + owners (ownerHex, email, name, passwordHash, joinDate, confirmedEmail) + VALUES ($1, $2, $3, $4, $5, $6 ); + ` + _, err = db.Exec(statement, ownerHex, email, name, string(passwordHash), time.Now().UTC(), !smtpConfigured) + if err != nil { + // TODO: Make sure `err` is actually about conflicting UNIQUE, and not some + // other error. If it is something else, we should probably return `errorInternal`. + return "", errorEmailAlreadyExists + } + + if smtpConfigured { + confirmHex, err := randomHex(32) + if err != nil { + logger.Errorf("cannot generate confirmHex: %v", err) + return "", errorInternal + } + + statement = ` + INSERT INTO + ownerConfirmHexes (confirmHex, ownerHex, sendDate) + VALUES ($1, $2, $3 ); + ` + _, err = db.Exec(statement, confirmHex, ownerHex, time.Now().UTC()) + if err != nil { + logger.Errorf("cannot insert confirmHex: %v\n", err) + return "", errorInternal + } + + if err = smtpOwnerConfirmHex(email, name, confirmHex); err != nil { + return "", err + } + } + + return ownerHex, nil +} + +func ownerNewHandler(w http.ResponseWriter, r *http.Request) { + type request struct { + Email *string `json:"email"` + Name *string `json:"name"` + Password *string `json:"password"` + } + + var x request + if err := unmarshalBody(r, &x); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + if _, err := ownerNew(*x.Email, *x.Name, *x.Password); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + writeBody(w, response{"success": true}) +} diff --git a/api/owner_new_test.go b/api/owner_new_test.go new file mode 100644 index 0000000..2c1bd42 --- /dev/null +++ b/api/owner_new_test.go @@ -0,0 +1,42 @@ +package main + +import ( + "testing" +) + +func TestOwnerNewBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if _, err := ownerNew("test@example.com", "Test", "hunter2"); err != nil { + t.Errorf("unexpected error when creating new owner: %v", err) + return + } +} + +func TestOwnerNewClash(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if _, err := ownerNew("test@example.com", "Test", "hunter2"); err != nil { + t.Errorf("unexpected error when creating new owner: %v", err) + return + } + + if _, err := ownerNew("test@example.com", "Test", "hunter2"); err == nil { + t.Errorf("expected error not found when creating with clashing email") + return + } +} + +func TestOwnerNewEmpty(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if _, err := ownerNew("test@example.com", "", "hunter2"); err == nil { + t.Errorf("expected error not found when passing empty name") + return + } + + if _, err := ownerNew("", "", ""); err == nil { + t.Errorf("expected error not found when passing empty everything") + return + } +} diff --git a/api/owner_reset_hex.go b/api/owner_reset_hex.go new file mode 100644 index 0000000..c2bf339 --- /dev/null +++ b/api/owner_reset_hex.go @@ -0,0 +1,70 @@ +package main + +import ( + "net/http" + "time" +) + +func ownerSendResetHex(email string) error { + if email == "" { + return errorMissingField + } + + o, err := ownerGetByEmail(email) + if err != nil { + if err == errorNoSuchEmail { + // TODO: use a more random time instead. + time.Sleep(1 * time.Second) + return nil + } else { + logger.Errorf("cannot get owner by email: %v", err) + return errorInternal + } + } + + if !smtpConfigured { + return errorSmtpNotConfigured + } + + resetHex, err := randomHex(32) + if err != nil { + return err + } + + statement := ` + INSERT INTO + ownerResetHexes (resetHex, ownerHex, sendDate) + VALUES ($1, $2, $3 ); + ` + _, err = db.Exec(statement, resetHex, o.OwnerHex, time.Now().UTC()) + if err != nil { + logger.Errorf("cannot insert resetHex: %v", err) + return errorInternal + } + + err = smtpOwnerResetHex(email, o.Name, resetHex) + if err != nil { + return err + } + + return nil +} + +func ownerSendResetHexHandler(w http.ResponseWriter, r *http.Request) { + type request struct { + Email *string `json:"email"` + } + + var x request + if err := unmarshalBody(r, &x); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + if err := ownerSendResetHex(*x.Email); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + writeBody(w, response{"success": true}) +} diff --git a/api/owner_reset_password.go b/api/owner_reset_password.go new file mode 100644 index 0000000..7eaf92c --- /dev/null +++ b/api/owner_reset_password.go @@ -0,0 +1,72 @@ +package main + +import ( + "golang.org/x/crypto/bcrypt" + "net/http" +) + +func ownerResetPassword(resetHex string, password string) error { + if resetHex == "" || password == "" { + return errorMissingField + } + + passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + logger.Errorf("cannot generate hash from password: %v\n", err) + return errorInternal + } + + statement := ` + UPDATE owners SET passwordHash=$1 + WHERE email IN ( + SELECT email FROM ownerResetHexes + WHERE resetHex=$2 + ); + ` + res, err := db.Exec(statement, string(passwordHash), resetHex) + if err != nil { + logger.Errorf("cannot change user's password: %v\n", err) + return errorInternal + } + + count, err := res.RowsAffected() + if err != nil { + logger.Errorf("cannot count rows affected: %v\n", err) + return errorInternal + } + + if count == 0 { + return errorNoSuchResetToken + } + + statement = ` + DELETE FROM ownerResetHexes + WHERE resetHex=$1; + ` + _, err = db.Exec(statement, resetHex) + if err != nil { + logger.Warningf("cannot remove reset token: %v\n", err) + } + + return nil +} + +func ownerResetPasswordHandler(w http.ResponseWriter, r *http.Request) { + type request struct { + ResetHex *string `json:"resetHex"` + Password *string `json:"password"` + } + + var x request + if err := unmarshalBody(r, &x); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + if err := ownerResetPassword(*x.ResetHex, *x.Password); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + writeBody(w, response{"success": true}) +} diff --git a/api/owner_reset_password_test.go b/api/owner_reset_password_test.go new file mode 100644 index 0000000..8c11a12 --- /dev/null +++ b/api/owner_reset_password_test.go @@ -0,0 +1,40 @@ +package main + +import ( + "testing" + "time" +) + +func TestOwnerResetPasswordBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + ownerHex, _ := ownerNew("test@example.com", "Test", "hunter2") + + resetHex, _ := randomHex(32) + + statement := ` + INSERT INTO + ownerResetHexes (resetHex, ownerHex, sendDate) + VALUES ($1, $2, $3 ); + ` + _, err := db.Exec(statement, resetHex, ownerHex, time.Now().UTC()) + if err != nil { + t.Errorf("unexpected error inserting resetHex: %v", err) + return + } + + if err = ownerResetPassword(resetHex, "hunter3"); err != nil { + t.Errorf("unexpected error resetting password: %v", err) + return + } + + if _, err := ownerLogin("test@example.com", "hunter2"); err == nil { + t.Errorf("expected error not found when given old password") + return + } + + if _, err := ownerLogin("test@example.com", "hunter3"); err != nil { + t.Errorf("unexpected error when logging in: %v", err) + return + } +} diff --git a/api/owner_self.go b/api/owner_self.go new file mode 100644 index 0000000..5a5d629 --- /dev/null +++ b/api/owner_self.go @@ -0,0 +1,30 @@ +package main + +import ( + "net/http" +) + +func ownerSelf(session string) (bool, owner) { + o, err := ownerGetBySession(session) + if err != nil { + return false, owner{} + } + + return true, o +} + +func ownerSelfHandler(w http.ResponseWriter, r *http.Request) { + type request struct { + Session *string `json:"session"` + } + + var x request + if err := unmarshalBody(r, &x); err != nil { + writeBody(w, response{"success": false, "message": err.Error()}) + return + } + + loggedIn, o := ownerSelf(*x.Session) + + writeBody(w, response{"success": true, "loggedIn": loggedIn, "owner": o}) +} diff --git a/api/owner_self_test.go b/api/owner_self_test.go new file mode 100644 index 0000000..20c4dd7 --- /dev/null +++ b/api/owner_self_test.go @@ -0,0 +1,32 @@ +package main + +import ( + "testing" +) + +func TestOwnerSelfBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + ownerNew("test@example.com", "Test", "hunter2") + session, _ := ownerLogin("test@example.com", "hunter2") + + loggedIn, o := ownerSelf(session) + if !loggedIn { + t.Errorf("expected loggedIn=true got loggedIn=false") + return + } + + if o.Name != "Test" { + t.Errorf("expected name=Test got name=%s", o.Name) + return + } +} + +func TestOwnerSelfNotLoggedIn(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + if loggedIn, _ := ownerSelf("does-not-exist"); loggedIn { + t.Errorf("expected loggedIn=false got loggedIn=true") + return + } +} diff --git a/api/router.go b/api/router.go new file mode 100644 index 0000000..1781cd6 --- /dev/null +++ b/api/router.go @@ -0,0 +1,32 @@ +package main + +import ( + "github.com/gorilla/handlers" + "github.com/gorilla/mux" + "net/http" + "os" +) + +func serveRoutes() error { + router := mux.NewRouter() + + if err := initAPIRouter(router); err != nil { + return err + } + + if err := initStaticRouter(router); err != nil { + return err + } + + origins := handlers.AllowedOrigins([]string{"*"}) + headers := handlers.AllowedHeaders([]string{"X-Requested-With"}) + methods := handlers.AllowedMethods([]string{"GET", "POST"}) + + logger.Infof("starting server on port %s\n", os.Getenv("PORT")) + if err := http.ListenAndServe(":"+os.Getenv("PORT"), handlers.CORS(origins, headers, methods)(router)); err != nil { + logger.Errorf("cannot start server: %v", err) + return err + } + + return nil +} diff --git a/api/router_api.go b/api/router_api.go new file mode 100644 index 0000000..729c3c0 --- /dev/null +++ b/api/router_api.go @@ -0,0 +1,34 @@ +package main + +import ( + "github.com/gorilla/mux" +) + +func initAPIRouter(router *mux.Router) error { + router.HandleFunc("/api/owner/new", ownerNewHandler).Methods("POST") + router.HandleFunc("/api/owner/confirm-hex", ownerConfirmHexHandler).Methods("GET") + router.HandleFunc("/api/owner/login", ownerLoginHandler).Methods("POST") + router.HandleFunc("/api/owner/send-reset-hex", ownerSendResetHexHandler).Methods("POST") + router.HandleFunc("/api/owner/reset-password", ownerResetPasswordHandler).Methods("POST") + router.HandleFunc("/api/owner/self", ownerSelfHandler).Methods("POST") + + router.HandleFunc("/api/domain/new", domainNewHandler).Methods("POST") + router.HandleFunc("/api/domain/list", domainListHandler).Methods("POST") + router.HandleFunc("/api/domain/update", domainUpdateHandler).Methods("POST") + router.HandleFunc("/api/domain/moderator/new", domainModeratorNewHandler).Methods("POST") + router.HandleFunc("/api/domain/moderator/delete", domainModeratorDeleteHandler).Methods("POST") + + router.HandleFunc("/api/commenter/session/new", commenterSessionNewHandler).Methods("GET") + router.HandleFunc("/api/commenter/self", commenterSelfHandler).Methods("POST") + + router.HandleFunc("/api/oauth/google/redirect", googleRedirectHandler).Methods("GET") + router.HandleFunc("/api/oauth/google/callback", googleCallbackHandler).Methods("GET") + + router.HandleFunc("/api/comment/new", commentNewHandler).Methods("POST") + router.HandleFunc("/api/comment/list", commentListHandler).Methods("POST") + router.HandleFunc("/api/comment/vote", commentVoteHandler).Methods("POST") + router.HandleFunc("/api/comment/approve", commentApproveHandler).Methods("POST") + router.HandleFunc("/api/comment/delete", commentDeleteHandler).Methods("POST") + + return nil +} diff --git a/api/router_static.go b/api/router_static.go new file mode 100644 index 0000000..33a6eec --- /dev/null +++ b/api/router_static.go @@ -0,0 +1,69 @@ +package main + +import ( + "bytes" + "fmt" + "github.com/gorilla/mux" + "html/template" + "io/ioutil" + "net/http" + "os" +) + +func redirectLogin(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/login", 301) +} + +type staticHtmlPlugs struct { + CdnPrefix string +} + +func initStaticRouter(router *mux.Router) error { + for _, path := range []string{"js", "css", "images"} { + router.PathPrefix("/" + path + "/").HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + f, err := os.Stat("." + r.URL.Path) + if err != nil || f.IsDir() { + http.NotFound(w, r) + } + + http.ServeFile(w, r, "."+r.URL.Path) + }) + } + + pages := []string{ + "login", + "signup", + "dashboard", + "account", + } + + html := make(map[string]string) + for _, page := range pages { + contents, err := ioutil.ReadFile(page + ".html") + if err != nil { + logger.Errorf("cannot read file %s.html: %v", page, err) + return err + } + + t, err := template.New(page).Parse(string(contents)) + if err != nil { + logger.Errorf("cannot parse %s.html template: %v", page, err) + return err + } + + var buf bytes.Buffer + t.Execute(&buf, &staticHtmlPlugs{CdnPrefix: os.Getenv("CDN_PREFIX")}) + + html[page] = buf.String() + } + + for _, page := range pages { + router.HandleFunc("/"+page, func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, html[page]) + }) + } + + router.HandleFunc("/", redirectLogin).Methods("GET") + + return nil +} diff --git a/api/smtp_configure.go b/api/smtp_configure.go new file mode 100644 index 0000000..779b40b --- /dev/null +++ b/api/smtp_configure.go @@ -0,0 +1,31 @@ +package main + +import ( + "net/smtp" + "os" +) + +var smtpConfigured bool +var smtpAuth smtp.Auth + +func smtpConfigure() error { + username := os.Getenv("SMTP_USERNAME") + password := os.Getenv("SMTP_PASSWORD") + host := os.Getenv("SMTP_HOST") + if username == "" || password == "" || host == "" { + logger.Warningf("smtp not configured, no emails will be sent") + smtpConfigured = false + return nil + } + + if os.Getenv("SMTP_FROM_ADDRESS") == "" { + logger.Errorf("SMTP_FROM_ADDRESS not set") + smtpConfigured = false + return errorMissingSmtpAddress + } + + logger.Infof("configuring smtp: %s", host) + smtpAuth = smtp.PlainAuth("", username, password, host) + smtpConfigured = true + return nil +} diff --git a/api/smtp_configure_test.go b/api/smtp_configure_test.go new file mode 100644 index 0000000..3194b84 --- /dev/null +++ b/api/smtp_configure_test.go @@ -0,0 +1,63 @@ +package main + +import ( + "os" + "testing" +) + +func cleanSmtpVars() { + for _, env := range []string{"SMTP_USERNAME", "SMTP_PASSWORD", "SMTP_HOST", "SMTP_FROM_ADDRESS"} { + os.Setenv(env, "") + } +} + +func TestSmtpConfigureBasics(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + os.Setenv("SMTP_USERNAME", "test@example.com") + os.Setenv("SMTP_PASSWORD", "hunter2") + os.Setenv("SMTP_HOST", "smtp.commento.io") + os.Setenv("SMTP_FROM_ADDRESS", "no-reply@commento.io") + + if err := smtpConfigure(); err != nil { + t.Errorf("unexpected error when configuring SMTP: %v", err) + return + } + + cleanSmtpVars() +} + +func TestSmtpConfigureEmptyHost(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + os.Setenv("SMTP_USERNAME", "test@example.com") + os.Setenv("SMTP_PASSWORD", "hunter2") + os.Setenv("SMTP_FROM_ADDRESS", "no-reply@commento.io") + + if err := smtpConfigure(); err != nil { + t.Errorf("unexpected error when configuring SMTP: %v", err) + return + } + + if smtpConfigured { + t.Errorf("SMTP configured when it should not be due to empty SMTP_HOST") + return + } + + cleanSmtpVars() +} + +func TestSmtpConfigureEmptyAddress(t *testing.T) { + failTestOnError(t, setupTestEnv()) + + os.Setenv("SMTP_USERNAME", "test@example.com") + os.Setenv("SMTP_PASSWORD", "hunter2") + os.Setenv("SMTP_HOST", "smtp.commento.io") + + if err := smtpConfigure(); err == nil { + t.Errorf("expected error not found; SMTP should not be configured when SMTP_FROM_ADDRESS is empty") + return + } + + cleanSmtpVars() +} diff --git a/api/smtp_owner_confirm_hex.go b/api/smtp_owner_confirm_hex.go new file mode 100644 index 0000000..ffc1454 --- /dev/null +++ b/api/smtp_owner_confirm_hex.go @@ -0,0 +1,28 @@ +package main + +import ( + "bytes" + "net/smtp" + "os" +) + +type ownerConfirmHexPlugs struct { + Origin string + ConfirmHex string +} + +func smtpOwnerConfirmHex(to string, toName string, confirmHex string) error { + var header bytes.Buffer + headerTemplate.Execute(&header, &headerPlugs{FromAddress: os.Getenv("SMTP_FROM_ADDRESS"), ToAddress: to, ToName: toName, Subject: "Please confirm your email address"}) + + var body bytes.Buffer + templates["confirm-hex"].Execute(&body, &ownerConfirmHexPlugs{Origin: os.Getenv("ORIGIN"), ConfirmHex: confirmHex}) + + err := smtp.SendMail(os.Getenv("SMTP_HOST"), smtpAuth, os.Getenv("SMTP_FROM_ADDRESS"), []string{to}, concat(header, body)) + if err != nil { + logger.Errorf("cannot send confirmation email: %v", err) + return errorCannotSendEmail + } + + return nil +} diff --git a/api/smtp_owner_reset_hex.go b/api/smtp_owner_reset_hex.go new file mode 100644 index 0000000..3bfaf4f --- /dev/null +++ b/api/smtp_owner_reset_hex.go @@ -0,0 +1,28 @@ +package main + +import ( + "bytes" + "net/smtp" + "os" +) + +type ownerResetHexPlugs struct { + Origin string + ResetHex string +} + +func smtpOwnerResetHex(to string, toName string, resetHex string) error { + var header bytes.Buffer + headerTemplate.Execute(&header, &headerPlugs{FromAddress: os.Getenv("SMTP_FROM_ADDRESS"), ToAddress: to, ToName: toName, Subject: "Reset your password"}) + + var body bytes.Buffer + templates["reset-hex"].Execute(&body, &ownerResetHexPlugs{Origin: os.Getenv("ORIGIN"), ResetHex: resetHex}) + + err := smtp.SendMail(os.Getenv("SMTP_HOST"), smtpAuth, os.Getenv("SMTP_FROM_ADDRESS"), []string{to}, concat(header, body)) + if err != nil { + logger.Errorf("cannot send reset email: %v", err) + return errorCannotSendEmail + } + + return nil +} diff --git a/api/smtp_templates.go b/api/smtp_templates.go new file mode 100644 index 0000000..2889121 --- /dev/null +++ b/api/smtp_templates.go @@ -0,0 +1,49 @@ +package main + +import ( + "fmt" + "html/template" +) + +var headerTemplate *template.Template + +type headerPlugs struct { + FromAddress string + ToName string + ToAddress string + Subject string +} + +var templates map[string]*template.Template + +func loadTemplates() error { + var err error + headerTemplate, err = template.New("header").Parse(`MIME-Version: 1.0 +Content-Type: text/html; charset=UTF-8 +From: {{.FromAddress}} +To: {{.ToName}} <{{.ToAddress}}> +Subject: {{.Subject}} + +`) + if err != nil { + logger.Fatalf("cannot parse header template: %v", err) + return errorMalformedTemplate + } + + names := []string{"confirm-hex"} + + templates = make(map[string]*template.Template) + + logger.Infof("loading templates: %v", names) + for _, name := range names { + var err error + templates[name] = template.New(name) + templates[name], err = template.ParseFiles(fmt.Sprintf("email/%s.html", name)) + if err != nil { + logger.Fatalf("cannot parse %s.html: %v\n", name, err) + return errorMalformedTemplate + } + } + + return nil +} diff --git a/api/testing.go b/api/testing.go new file mode 100644 index 0000000..f91777c --- /dev/null +++ b/api/testing.go @@ -0,0 +1,128 @@ +package main + +import ( + "fmt" + "github.com/op/go-logging" + "os" + "testing" +) + +func failTestOnError(t *testing.T, err error) { + if err != nil { + t.Errorf("failed test: %v", err) + } +} + +func getPublicTables() ([]string, error) { + statement := ` + SELECT tablename + FROM pg_tables + WHERE schemaname='public'; + ` + rows, err := db.Query(statement) + if err != nil { + fmt.Fprintf(os.Stderr, "cannot query public tables: %v", err) + return []string{}, err + } + + defer rows.Close() + + tables := []string{} + for rows.Next() { + var table string + if err = rows.Scan(&table); err != nil { + fmt.Fprintf(os.Stderr, "cannot scan table name: %v", err) + return []string{}, err + } + + tables = append(tables, table) + } + + return tables, nil +} + +func dropTables() error { + tables, err := getPublicTables() + if err != nil { + return err + } + + for _, table := range tables { + if table != "migrations" { + _, err = db.Exec(fmt.Sprintf("DROP TABLE %s;", table)) + if err != nil { + fmt.Fprintf(os.Stderr, "cannot drop %s: %v", table, err) + return err + } + } + } + + return nil +} + +func setupTestDatabase() error { + os.Setenv("POSTGRES", "postgres://postgres:postgres@0.0.0.0/commento_test?sslmode=disable") + + if err := connectDB(); err != nil { + return err + } + + if err := dropTables(); err != nil { + return err + } + + if err := performMigrationsFromDir("../db/"); err != nil { + return err + } + + return nil +} + +func clearTables() error { + tables, err := getPublicTables() + if err != nil { + return err + } + + for _, table := range tables { + _, err = db.Exec(fmt.Sprintf("DELETE FROM %s;", table)) + if err != nil { + fmt.Fprintf(os.Stderr, "cannot clear %s: %v", table, err) + return err + } + } + + return nil +} + +var setupComplete bool + +func setupTestEnv() error { + if !setupComplete { + setupComplete = true + + if err := createLogger(); err != nil { + return err + } + + // Print messages to console only if verbose. Sounds like a good idea to + // keep the console clean on `go test`. + if !testing.Verbose() { + logging.SetLevel(logging.CRITICAL, "") + } + + if err := setupTestDatabase(); err != nil { + return err + } + + if err := createMarkdownRenderer(); err != nil { + return err + } + } + + if err := clearTables(); err != nil { + return err + } + + return nil +} diff --git a/api/utils_crypto.go b/api/utils_crypto.go new file mode 100644 index 0000000..cd68572 --- /dev/null +++ b/api/utils_crypto.go @@ -0,0 +1,16 @@ +package main + +import ( + "crypto/rand" + "encoding/hex" +) + +func randomHex(n int) (string, error) { + b := make([]byte, n) + if _, err := rand.Read(b); err != nil { + logger.Errorf("cannot create %d-byte long random hex: %v\n", n, err) + return "", errorInternal + } + + return hex.EncodeToString(b), nil +} diff --git a/api/utils_crypto_test.go b/api/utils_crypto_test.go new file mode 100644 index 0000000..7ab50ae --- /dev/null +++ b/api/utils_crypto_test.go @@ -0,0 +1,29 @@ +package main + +import ( + "testing" +) + +func TestRandomHexBasics(t *testing.T) { + hex1, err := randomHex(32) + if err != nil { + t.Errorf("unexpected error creating hex: %v", err) + return + } + + if hex1 == "" { + t.Errorf("randomly generated hex empty") + return + } + + hex2, err := randomHex(32) + if err != nil { + t.Errorf("unexpected error creating hex: %v", err) + return + } + + if hex1 == hex2 { + t.Errorf("two randomly generated hexes found to be the same: '%s'", hex1) + return + } +} diff --git a/api/utils_http.go b/api/utils_http.go new file mode 100644 index 0000000..3b1c084 --- /dev/null +++ b/api/utils_http.go @@ -0,0 +1,45 @@ +package main + +import ( + "encoding/json" + "io/ioutil" + "net/http" + "reflect" +) + +type response map[string]interface{} + +// TODO: Add tests in utils_http_test.go + +func unmarshalBody(r *http.Request, x interface{}) error { + b, err := ioutil.ReadAll(r.Body) + if err != nil { + logger.Errorf("cannot read POST body: %v\n", err) + return errorInternal + } + + if err = json.Unmarshal(b, x); err != nil { + return errorInvalidJSONBody + } + + xv := reflect.Indirect(reflect.ValueOf(x)) + for i := 0; i < xv.NumField(); i++ { + if xv.Field(i).IsNil() { + return errorMissingField + } + } + + return nil +} + +func writeBody(w http.ResponseWriter, x map[string]interface{}) error { + resp, err := json.Marshal(x) + if err != nil { + w.Write([]byte(`{"success":false,"message":"Some internal error occurred"}`)) + logger.Errorf("cannot marshal response: %v\n") + return errorInternal + } + + w.Write(resp) + return nil +} diff --git a/api/utils_logging.go b/api/utils_logging.go new file mode 100644 index 0000000..811337d --- /dev/null +++ b/api/utils_logging.go @@ -0,0 +1,15 @@ +package main + +import ( + "github.com/op/go-logging" +) + +var logger *logging.Logger + +func createLogger() error { + format := logging.MustStringFormatter("[%{level}] %{shortfile} %{shortfunc}(): %{message}") + logging.SetFormatter(format) + logger = logging.MustGetLogger("commento") + + return nil +} diff --git a/api/utils_logging_test.go b/api/utils_logging_test.go new file mode 100644 index 0000000..dda40c4 --- /dev/null +++ b/api/utils_logging_test.go @@ -0,0 +1,21 @@ +package main + +import ( + "testing" +) + +func TestCreateLoggerBasics(t *testing.T) { + logger = nil + + if err := createLogger(); err != nil { + t.Errorf("unexpected error creating logger: %v", err) + return + } + + if logger == nil { + t.Errorf("logger null after createLogger()") + return + } + + logger.Debugf("test message please ignore") +} diff --git a/api/utils_misc.go b/api/utils_misc.go new file mode 100644 index 0000000..3aca03e --- /dev/null +++ b/api/utils_misc.go @@ -0,0 +1,16 @@ +package main + +import ( + "bytes" + "os" +) + +func concat(a bytes.Buffer, b bytes.Buffer) []byte { + return append(a.Bytes(), b.Bytes()...) +} + +func exitIfError(err error) { + if err != nil { + os.Exit(1) + } +} diff --git a/api/utils_sanitise.go b/api/utils_sanitise.go new file mode 100644 index 0000000..f4d7f57 --- /dev/null +++ b/api/utils_sanitise.go @@ -0,0 +1,35 @@ +package main + +import ( + "regexp" +) + +var prePlusMatch = regexp.MustCompile(`([^@\+]*)\+?(.*)@.*`) +var periodsMatch = regexp.MustCompile(`[\.]`) +var postAtMatch = regexp.MustCompile(`[^@]*(@.*)`) + +func stripEmail(email string) string { + postAt := postAtMatch.ReplaceAllString(email, `$1`) + prePlus := prePlusMatch.ReplaceAllString(email, `$1`) + strippedEmail := periodsMatch.ReplaceAllString(prePlus, ``) + postAt + + return strippedEmail +} + +var https = regexp.MustCompile(`(https?://)`) +var trailingSlash = regexp.MustCompile(`(/*$)`) + +func stripDomain(domain string) string { + noSlash := trailingSlash.ReplaceAllString(domain, ``) + noProtocol := https.ReplaceAllString(noSlash, ``) + + return noProtocol +} + +var path = regexp.MustCompile(`(https?://[^/]*)`) + +func stripPath(url string) string { + strippedPath := path.ReplaceAllString(url, ``) + + return strippedPath +} diff --git a/api/utils_sanitise_test.go b/api/utils_sanitise_test.go new file mode 100644 index 0000000..e183f7e --- /dev/null +++ b/api/utils_sanitise_test.go @@ -0,0 +1,20 @@ +package main + +import ( + "testing" +) + +func TestStripEmailBasics(t *testing.T) { + tests := map[string]string{ + "test@example.com": "test@example.com", + "test+strip@example.com": "test@example.com", + "test+strip+strip2@example.com": "test@example.com", + } + + for in, out := range tests { + if stripEmail(in) != out { + t.Errorf("for in=%s expected out=%s got out=%s", in, out, stripEmail(in)) + return + } + } +}