diff --git a/.gitignore b/.gitignore index 2a93830..e137b98 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ test.db +user_dbs/ todo-web diff --git a/api/api.go b/api/api.go index b5428a1..efbcb34 100644 --- a/api/api.go +++ b/api/api.go @@ -2,16 +2,18 @@ package api import ( "encoding/json" + "errors" "fmt" "net/http" "strconv" - "github.com/Cameron-Reed1/todo-web/db" + // "github.com/Cameron-Reed1/todo-web/db" "github.com/Cameron-Reed1/todo-web/types" ) func GetAll(w http.ResponseWriter, r *http.Request) { - todos, err := db.GetAllTodos() + // todos, err := db.GetAllTodos() + todos, err := []types.Todo(nil), errors.New("Broken right now :/") if err != nil { w.WriteHeader(http.StatusInternalServerError) w.Write([]byte("{\"error\":\"Failed to get items\"}")) @@ -44,7 +46,8 @@ func GetTodo(w http.ResponseWriter, r *http.Request) { return } - todo, err := db.GetTodo(id) + // todo, err := db.GetTodo(id) + todo, err := types.Todo{Id: int64(id)}, errors.New("Broken right now :/") if err != nil { w.WriteHeader(http.StatusInternalServerError) w.Write([]byte("{\"error\":\"No item for id\"}")) @@ -79,7 +82,8 @@ func AddTodo(w http.ResponseWriter, r *http.Request) { return } - err = db.AddTodo(&todo) + // err = db.AddTodo(&todo) + err = errors.New("Broken right now :/") if err != nil { w.WriteHeader(http.StatusInternalServerError) w.Write([]byte("{\"error\":\"Failed to add item\"}")) diff --git a/auth/auth.go b/auth/auth.go new file mode 100644 index 0000000..edd951e --- /dev/null +++ b/auth/auth.go @@ -0,0 +1,133 @@ +package auth + +import ( + "bytes" + "crypto/rand" + "encoding/base64" + "fmt" + + "github.com/Cameron-Reed1/todo-web/types" + "golang.org/x/crypto/argon2" + "golang.org/x/crypto/scrypt" +) + + +var algorithm argon2idHasher = argon2idHasher{ + hashLen: 64, + saltLen: 32, + time: 6, + memory: 24 * 1024, + threads: 1, +} + + +func Hash(password, salt []byte) (*HashSalt, error) { + return algorithm.Hash(password, salt) +} + +func Validate(hash, salt, password []byte) bool { + return algorithm.Validate(hash, salt, password) +} + +func CreateSessionFor(user_id int64) (*types.Session, error) { + buf := make([]byte, 32) + _, err := rand.Read(buf) + if err != nil { + return nil, err + } + + return &types.Session{ SessionId: base64.StdEncoding.EncodeToString(buf), UserId: user_id }, nil +} + + +func generateSalt(length uint) ([]byte, error) { + salt := make([]byte, length) + + _, err := rand.Read(salt) + if err != nil { + return nil, err + } + + return salt, nil +} + +type HashSalt struct { + Hash []byte + Salt []byte +} + +type hashAlgo interface { + Hash(password, salt []byte) ([]byte, error) + Validate(hash, salt, password []byte) bool +} + +type scryptHasher struct { + hashLen int + saltLen uint + cost int + blockSize int + parallelism int +} + +type argon2idHasher struct { + hashLen uint32 + saltLen uint + time uint32 + memory uint32 + threads uint8 +} + +func (s *scryptHasher) Hash(password, salt []byte) (*HashSalt, error) { + var err error + + if salt == nil || len(salt) == 0 { + salt, err = generateSalt(s.saltLen) + if err != nil { + fmt.Println("\x1b[31mError: Failed to generate a password salt\x1b[0m") + return nil, err + } + } + + hash, err := scrypt.Key(password, salt, s.cost, s.blockSize, s.parallelism, s.hashLen) + if err != nil { + return nil, err + } + + return &HashSalt{Hash: hash, Salt: salt}, nil +} + +func (s *scryptHasher) Validate(hash, salt, password []byte) bool { + hashed_password, err := s.Hash(password, salt) + if err != nil { + fmt.Println("\x1b[31mError: Failed to generate a password hash\x1b[0m") + return false + } + + return bytes.Equal(hash, hashed_password.Hash) +} + +func (a *argon2idHasher) Hash(password, salt []byte) (*HashSalt, error) { + var err error + + if salt == nil || len(salt) == 0 { + salt, err = generateSalt(a.saltLen) + if err != nil { + fmt.Println("\x1b[31mError: Failed to generate a password salt\x1b[0m") + return nil, err + } + } + + hash := argon2.IDKey(password, salt, a.time, a.memory, a.threads, a.hashLen) + + return &HashSalt{Hash: hash, Salt: salt}, nil +} + +func (s *argon2idHasher) Validate(hash, salt, password []byte) bool { + hashed_password, err := s.Hash(password, salt) + if err != nil { + fmt.Println("\x1b[31mError: Failed to generate a password hash\x1b[0m") + return false + } + + return bytes.Equal(hash, hashed_password.Hash) +} diff --git a/db/db.go b/db/db.go index c1e1136..75801e6 100644 --- a/db/db.go +++ b/db/db.go @@ -2,220 +2,10 @@ package db import ( "database/sql" - "log" - "time" _ "github.com/mattn/go-sqlite3" - - "github.com/Cameron-Reed1/todo-web/types" ) -var db *sql.DB - -func Open(path string) { - if db != nil { - log.Fatal("Cannot init DB twice!") - } - - var err error - db, err = sql.Open("sqlite3", path) - if err != nil { - log.Fatal(err) - } - - err = db.Ping() - if err != nil { - log.Fatal(err) - } - - query := ` - CREATE TABLE IF NOT EXISTS items ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - start INTEGER, - due INTEGER, - text TEXT NOT NULL, - completed INTEGER NOT NULL - );` - - _, err = db.Exec(query) -} - -func AddTodo(todo *types.Todo) error { - res, err := db.Exec("INSERT INTO items(start, due, text, completed) values(?, ?, ?, 0)", toNullInt64(todo.Start), toNullInt64(todo.Due), todo.Text) - if err != nil { - return err - } - - todo.Id, err = res.LastInsertId() - if err != nil { - return err - } - - return nil -} - -func GetTodo(id int) (types.Todo, error) { - var todo types.Todo - var start sql.NullInt64 - var due sql.NullInt64 - - row := db.QueryRow("SELECT * FROM items WHERE id=?", id) - err := row.Scan(&todo.Id, &start, &due, &todo.Text, &todo.Completed) - - todo.Start = fromNullInt64(start) - todo.Due = fromNullInt64(due) - - return todo, err -} - -func GetAllTodos() ([]types.Todo, error) { - var todos []types.Todo - - rows, err := db.Query("SELECT * FROM items") - if err != nil { - return nil, err - } - - for rows.Next() { - var todo types.Todo - var start sql.NullInt64 - var due sql.NullInt64 - - err = rows.Scan(&todo.Id, &start, &due, &todo.Text, &todo.Completed) - if err != nil { - return nil, err - } - - todo.Start = fromNullInt64(start) - todo.Due = fromNullInt64(due) - - todos = append(todos, todo) - } - - if err = rows.Err(); err != nil { - return nil, err - } - - return todos, nil -} - -func GetOverdueTodos() ([]types.Todo, error) { - var todos []types.Todo - - rows, err := db.Query("SELECT * FROM items WHERE due < ? AND due IS NOT NULL ORDER BY completed, due", time.Now().Unix()) - if err != nil { - return nil, err - } - - for rows.Next() { - var todo types.Todo - var start sql.NullInt64 - var due sql.NullInt64 - - err = rows.Scan(&todo.Id, &start, &due, &todo.Text, &todo.Completed) - if err != nil { - return nil, err - } - - todo.Start = fromNullInt64(start) - todo.Due = fromNullInt64(due) - - todos = append(todos, todo) - } - - if err = rows.Err(); err != nil { - return nil, err - } - - return todos, nil -} - -func GetTodayTodos() ([]types.Todo, error) { - var todos []types.Todo - - now := time.Now() - year, month, day := now.Date() - today := time.Date(year, month, day, 0, 0, 0, 0, time.Local) - rows, err := db.Query("SELECT * FROM items WHERE (start <= ? OR start IS NULL) AND (due >= ? OR due IS NULL) ORDER BY completed, due NULLS LAST", today.Unix(), now.Unix()) - if err != nil { - return nil, err - } - - for rows.Next() { - var todo types.Todo - var start sql.NullInt64 - var due sql.NullInt64 - - err = rows.Scan(&todo.Id, &start, &due, &todo.Text, &todo.Completed) - if err != nil { - return nil, err - } - - todo.Start = fromNullInt64(start) - todo.Due = fromNullInt64(due) - - todos = append(todos, todo) - } - - if err = rows.Err(); err != nil { - return nil, err - } - - return todos, nil -} - -func GetUpcomingTodos() ([]types.Todo, error) { - var todos []types.Todo - - year, month, day := time.Now().Date() - today := time.Date(year, month, day, 0, 0, 0, 0, time.Local) - rows, err := db.Query("SELECT * FROM items WHERE start > ? ORDER BY completed, start", today.Unix()) - if err != nil { - return nil, err - } - - for rows.Next() { - var todo types.Todo - var start sql.NullInt64 - var due sql.NullInt64 - - err = rows.Scan(&todo.Id, &start, &due, &todo.Text, &todo.Completed) - if err != nil { - return nil, err - } - - todo.Start = fromNullInt64(start) - todo.Due = fromNullInt64(due) - - todos = append(todos, todo) - } - - if err = rows.Err(); err != nil { - return nil, err - } - - return todos, nil -} - -func UpdateTodo(newValues types.Todo) error { - _, err := db.Exec("UPDATE items SET start=?, due=?, text=? WHERE id=?", toNullInt64(newValues.Start), toNullInt64(newValues.Due), newValues.Text, newValues.Id) - return err; -} - -func SetCompleted(id int, completed bool) error { - _, err := db.Exec("UPDATE items SET completed=? WHERE id=?", completed, id) - return err -} - -func DeleteTodo(id int) error { - _, err := db.Exec("DELETE FROM items WHERE id=?", id) - return err -} - -func Close() { - db.Close() -} - func toNullInt64(num int64) sql.NullInt64 { if num == 0 { return sql.NullInt64{Int64: 0, Valid: false} diff --git a/db/main_db.go b/db/main_db.go new file mode 100644 index 0000000..2cce30a --- /dev/null +++ b/db/main_db.go @@ -0,0 +1,133 @@ +package db + +import ( + "database/sql" + "encoding/hex" + "log" + + "github.com/Cameron-Reed1/todo-web/auth" + "github.com/Cameron-Reed1/todo-web/types" + _ "github.com/mattn/go-sqlite3" +) + + +var main_db *sql.DB + + +func OpenMainDB(path string) { + if main_db != nil { + log.Fatal("Cannot open main DB twice!") + } + + var err error + main_db, err = sql.Open("sqlite3", path) + if err != nil { + log.Fatal(err) + } + + err = main_db.Ping() + if err != nil { + log.Fatal(err) + } + + query := ` + CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT NOT NULL UNIQUE, + password_hash TEXT NOT NULL, + password_salt TEXT NOT NULL + ); + CREATE TABLE IF NOT EXISTS sessions ( + sessionId TEXT NOT NULL, + user_id INTEGER NOT NULL, + FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE + );` + + _, err = main_db.Exec(query) + if err != nil { + log.Fatal(err) + } +} + +func CloseMainDB() { + main_db.Close() +} + +func CreateUser(username string, password_hash, password_salt []byte) (int64, error) { + hex_hash := hex.EncodeToString(password_hash) + hex_salt := hex.EncodeToString(password_salt) + + res, err := main_db.Exec("INSERT INTO users(username, password_hash, password_salt) values(?, ?, ?)", username, hex_hash, hex_salt) + if err != nil { + return 0, err + } + + return res.LastInsertId() +} + +func GetUserPassHash(username string) (int64, *auth.HashSalt, error) { + hashSalt := auth.HashSalt{} + var user_id int64 + var hex_hash string + var hex_salt string + + row := main_db.QueryRow("SELECT id, password_hash, password_salt FROM users WHERE username=?", username) + err := row.Scan(&user_id, &hex_hash, &hex_salt) + if err != nil { + return 0, nil, err + } + + hashSalt.Hash, err = hex.DecodeString(hex_hash) + if err != nil { + return 0, nil, err + } + hashSalt.Salt, err = hex.DecodeString(hex_salt) + if err != nil { + return 0, nil, err + } + + return user_id, &hashSalt, nil +} + +func DeleteUser(username string) error { + _, err := main_db.Exec("DELETE FROM users WHERE username=?", username) + return err +} + + +func AddSession(session *types.Session) error { + // fmt.Printf("New session: %s, %d\n", session.SessionId, session.UserId) + _, err := main_db.Exec("INSERT INTO sessions(sessionId, user_id) values(?, ?)", session.SessionId, session.UserId) + // fmt.Printf("Err: %v\n", err) + return err +} + +func GetUserFromSession(sessionId string) (string, error) { + var username string + + row := main_db.QueryRow("SELECT username FROM sessions INNER JOIN users ON sessions.user_id = users.id WHERE sessionId=?", sessionId) + err := row.Scan(&username) + if err != nil { + return "", err + } + + return username, nil +} + +func GetSession(sessionId string) (*types.Session, error) { + var session types.Session + + row := main_db.QueryRow("SELECT user_id FROM sessions WHERE sessionId=?", sessionId) + session.SessionId = sessionId + err := row.Scan(&session.UserId) + if err != nil { + return nil, err + } + + return &session, nil +} + +func DeleteSession(sessionId string) error { + _, err := main_db.Exec("DELETE FROM sessions WHERE sessionId=?", sessionId) + return err +} diff --git a/db/user_db.go b/db/user_db.go new file mode 100644 index 0000000..73677ba --- /dev/null +++ b/db/user_db.go @@ -0,0 +1,233 @@ +package db + +import ( + "database/sql" + "errors" + "os" + "path" + "strings" + "time" + + _ "github.com/mattn/go-sqlite3" + + "github.com/Cameron-Reed1/todo-web/types" +) + +var userDbDir string + +type UserDB struct { + DB *sql.DB +} + +func SetUserDBDir(dir string) error { + os.MkdirAll(dir, 0700) + userDbDir = dir + return nil +} + +func OpenUserDB(username string) (*UserDB, error) { + if strings.Contains(username, ".") || strings.Contains(username, "/") { + return nil, errors.New("Invalid username") + } + + path := path.Join(userDbDir, username + ".db") + + db, err := sql.Open("sqlite3", path) + if err != nil { + return nil, err + } + + err = db.Ping() + if err != nil { + return nil, err + } + + query := ` + CREATE TABLE IF NOT EXISTS items ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + start INTEGER, + due INTEGER, + text TEXT NOT NULL, + completed INTEGER NOT NULL + );` + + _, err = db.Exec(query) + + return &UserDB{DB: db}, err +} + +func (user_db *UserDB) Close() error { + return user_db.DB.Close() +} + +func (user_db *UserDB) AddTodo(todo *types.Todo) error { + res, err := user_db.DB.Exec("INSERT INTO items(start, due, text, completed) values(?, ?, ?, 0)", toNullInt64(todo.Start), toNullInt64(todo.Due), todo.Text) + if err != nil { + return err + } + + todo.Id, err = res.LastInsertId() + if err != nil { + return err + } + + return nil +} + +func (user_db *UserDB) GetTodo(id int) (types.Todo, error) { + var todo types.Todo + var start sql.NullInt64 + var due sql.NullInt64 + + row := user_db.DB.QueryRow("SELECT * FROM items WHERE id=?", id) + err := row.Scan(&todo.Id, &start, &due, &todo.Text, &todo.Completed) + + todo.Start = fromNullInt64(start) + todo.Due = fromNullInt64(due) + + return todo, err +} + +func (user_db *UserDB) GetAllTodos() ([]types.Todo, error) { + var todos []types.Todo + + rows, err := user_db.DB.Query("SELECT * FROM items") + if err != nil { + return nil, err + } + + for rows.Next() { + var todo types.Todo + var start sql.NullInt64 + var due sql.NullInt64 + + err = rows.Scan(&todo.Id, &start, &due, &todo.Text, &todo.Completed) + if err != nil { + return nil, err + } + + todo.Start = fromNullInt64(start) + todo.Due = fromNullInt64(due) + + todos = append(todos, todo) + } + + if err = rows.Err(); err != nil { + return nil, err + } + + return todos, nil +} + +func (user_db *UserDB) GetOverdueTodos() ([]types.Todo, error) { + var todos []types.Todo + + rows, err := user_db.DB.Query("SELECT * FROM items WHERE due < ? AND due IS NOT NULL ORDER BY completed, due", time.Now().Unix()) + if err != nil { + return nil, err + } + + for rows.Next() { + var todo types.Todo + var start sql.NullInt64 + var due sql.NullInt64 + + err = rows.Scan(&todo.Id, &start, &due, &todo.Text, &todo.Completed) + if err != nil { + return nil, err + } + + todo.Start = fromNullInt64(start) + todo.Due = fromNullInt64(due) + + todos = append(todos, todo) + } + + if err = rows.Err(); err != nil { + return nil, err + } + + return todos, nil +} + +func (user_db *UserDB) GetTodayTodos() ([]types.Todo, error) { + var todos []types.Todo + + now := time.Now() + year, month, day := now.Date() + today := time.Date(year, month, day, 0, 0, 0, 0, time.Local) + rows, err := user_db.DB.Query("SELECT * FROM items WHERE (start <= ? OR start IS NULL) AND (due >= ? OR due IS NULL) ORDER BY completed, due NULLS LAST", today.Unix(), now.Unix()) + if err != nil { + return nil, err + } + + for rows.Next() { + var todo types.Todo + var start sql.NullInt64 + var due sql.NullInt64 + + err = rows.Scan(&todo.Id, &start, &due, &todo.Text, &todo.Completed) + if err != nil { + return nil, err + } + + todo.Start = fromNullInt64(start) + todo.Due = fromNullInt64(due) + + todos = append(todos, todo) + } + + if err = rows.Err(); err != nil { + return nil, err + } + + return todos, nil +} + +func (user_db *UserDB) GetUpcomingTodos() ([]types.Todo, error) { + var todos []types.Todo + + year, month, day := time.Now().Date() + today := time.Date(year, month, day, 0, 0, 0, 0, time.Local) + rows, err := user_db.DB.Query("SELECT * FROM items WHERE start > ? ORDER BY completed, start", today.Unix()) + if err != nil { + return nil, err + } + + for rows.Next() { + var todo types.Todo + var start sql.NullInt64 + var due sql.NullInt64 + + err = rows.Scan(&todo.Id, &start, &due, &todo.Text, &todo.Completed) + if err != nil { + return nil, err + } + + todo.Start = fromNullInt64(start) + todo.Due = fromNullInt64(due) + + todos = append(todos, todo) + } + + if err = rows.Err(); err != nil { + return nil, err + } + + return todos, nil +} + +func (user_db *UserDB) UpdateTodo(newValues types.Todo) error { + _, err := user_db.DB.Exec("UPDATE items SET start=?, due=?, text=? WHERE id=?", toNullInt64(newValues.Start), toNullInt64(newValues.Due), newValues.Text, newValues.Id) + return err; +} + +func (user_db *UserDB) SetCompleted(id int, completed bool) error { + _, err := user_db.DB.Exec("UPDATE items SET completed=? WHERE id=?", completed, id) + return err +} + +func (user_db *UserDB) DeleteTodo(id int) error { + _, err := user_db.DB.Exec("DELETE FROM items WHERE id=?", id) + return err +} diff --git a/go.mod b/go.mod index 7242360..890fc64 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,10 @@ module github.com/Cameron-Reed1/todo-web go 1.22.6 -require github.com/mattn/go-sqlite3 v1.14.22 +require ( + github.com/a-h/templ v0.2.747 + github.com/mattn/go-sqlite3 v1.14.22 + golang.org/x/crypto v0.27.0 +) -require github.com/a-h/templ v0.2.747 // indirect +require golang.org/x/sys v0.25.0 // indirect diff --git a/go.sum b/go.sum index 3f9af5c..16990f5 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,10 @@ github.com/a-h/templ v0.2.747 h1:D0dQ2lxC3W7Dxl6fxQ/1zZHBQslSkTSvl5FxP/CfdKg= github.com/a-h/templ v0.2.747/go.mod h1:69ObQIbrcuwPCU32ohNaWce3Cb7qM5GMiqN1K+2yop4= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +golang.org/x/crypto v0.27.0 h1:GXm2NjJrPaiv/h1tb2UH8QfgC/hOf/+z0p6PT8o1w7A= +golang.org/x/crypto v0.27.0/go.mod h1:1Xngt8kV6Dvbssa53Ziq6Eqn0HqbZi5Z6R0ZpwQzt70= +golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= diff --git a/main.go b/main.go index 1043d3e..43fce39 100644 --- a/main.go +++ b/main.go @@ -12,32 +12,26 @@ import ( ) func main() { - db_path := flag.String("db", "./test.db", "Path to the sqlite3 database") + db_path := flag.String("db", "./main.db", "Path to the main sqlite3 database") + user_db_dir := flag.String("user-dbs", "./user_dbs", "Path to the directory containing per-user sqlite3 databases") bind_port := flag.Int("p", 8080, "Port to bind to") bind_addr := flag.String("a", "0.0.0.0", "Address to bind to") static_dir := flag.String("static", "./static", "Path to static files") noFront := flag.Bool("no-frontend", false, "Disable the frontend endpoints") - a := false; noBack := &a // flag.Bool("no-backend", false, "Disable the backend endpoints") + // a := false; noBack := &a // flag.Bool("no-backend", false, "Disable the backend endpoints") // This didn't really make sense flag.Parse() mux := http.NewServeMux() - if *noFront && *noBack { - fmt.Println("What do you want me to do?") - return - } - if !*noFront { addFrontendEndpoints(mux, *static_dir) } + addBackendEndpoints(mux) - if !*noBack { - addBackendEndpoints(mux) - } - - db.Open(*db_path) - defer db.Close() + db.SetUserDBDir(*user_db_dir) + db.OpenMainDB(*db_path) + defer db.CloseMainDB() addr := fmt.Sprintf("%s:%d", *bind_addr, *bind_port) server := http.Server{ Addr: addr, Handler: mux } @@ -60,6 +54,9 @@ func addFrontendEndpoints(mux *http.ServeMux, static_path string) { mux.HandleFunc("/overdue", pages.OverdueFragment) mux.HandleFunc("/today", pages.TodayFragment) mux.HandleFunc("/upcoming", pages.UpcomingFragment) + mux.HandleFunc("/login", pages.Login) + mux.HandleFunc("/create-account", pages.CreateAccount) + mux.HandleFunc("POST /logout", pages.Logout) mux.HandleFunc("DELETE /delete/{id}", pages.DeleteItem) mux.HandleFunc("PATCH /set/{id}", pages.SetItemCompleted) mux.HandleFunc("PUT /update", pages.UpdateItem) diff --git a/pages/fragments.go b/pages/fragments.go index 610d99a..ca45fd4 100644 --- a/pages/fragments.go +++ b/pages/fragments.go @@ -3,12 +3,18 @@ package pages import ( "net/http" - "github.com/Cameron-Reed1/todo-web/db" "github.com/Cameron-Reed1/todo-web/pages/templates" ) func OverdueFragment(w http.ResponseWriter, r *http.Request) { - items, err := db.GetOverdueTodos() + user_db, err := validateSessionAndGetUserDB(r) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + defer user_db.Close() + + items, err := user_db.GetOverdueTodos() if err != nil { w.WriteHeader(http.StatusInternalServerError) return @@ -18,7 +24,14 @@ func OverdueFragment(w http.ResponseWriter, r *http.Request) { } func TodayFragment(w http.ResponseWriter, r *http.Request) { - items, err := db.GetTodayTodos() + user_db, err := validateSessionAndGetUserDB(r) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + defer user_db.Close() + + items, err := user_db.GetTodayTodos() if err != nil { w.WriteHeader(http.StatusInternalServerError) return @@ -28,7 +41,14 @@ func TodayFragment(w http.ResponseWriter, r *http.Request) { } func UpcomingFragment(w http.ResponseWriter, r *http.Request) { - items, err := db.GetUpcomingTodos() + user_db, err := validateSessionAndGetUserDB(r) + if err != nil { + w.WriteHeader(http.StatusUnauthorized) + return + } + defer user_db.Close() + + items, err := user_db.GetUpcomingTodos() if err != nil { w.WriteHeader(http.StatusInternalServerError) return diff --git a/pages/login.go b/pages/login.go new file mode 100644 index 0000000..c2cadc9 --- /dev/null +++ b/pages/login.go @@ -0,0 +1,108 @@ +package pages + +import ( + "net/http" + + "github.com/Cameron-Reed1/todo-web/auth" + "github.com/Cameron-Reed1/todo-web/db" + "github.com/Cameron-Reed1/todo-web/pages/templates" +) + +func Login(w http.ResponseWriter, r *http.Request) { + if r.Method == "GET" { + if _, err := validateSession(r); err == nil { + w.Header().Add("Location", "/") + w.WriteHeader(http.StatusSeeOther) + } else { + templates.LoginPage().Render(r.Context(), w) + } + return + } + + username := r.FormValue("username") + password := r.FormValue("password") + stay_logged := r.FormValue("stay-logged-in") == "on" + + if username == "" || password == "" { + w.WriteHeader(http.StatusBadRequest) + return + } + + userId, hashSalt, err := db.GetUserPassHash(username) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + if auth.Validate(hashSalt.Hash, hashSalt.Salt, []byte(password)) { + session, err := createSession(userId) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Add("Set-Cookie", session.ToCookie(stay_logged)) + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusUnauthorized) + } +} + +func CreateAccount(w http.ResponseWriter, r *http.Request) { + if r.Method == "GET" { + if _, err := validateSession(r); err == nil { + w.Header().Add("Location", "/") + w.WriteHeader(http.StatusSeeOther) + } else { + templates.CreateAccountBox().Render(r.Context(), w) + } + return + } + + username := r.FormValue("username") + password := r.FormValue("password") + + if username == "" || password == "" { + w.WriteHeader(http.StatusBadRequest) + return + } + + // TODO: validate credentials + // Ensure the username is valid and is not taken + // Ensure that the password meets requirements + + hashSalt, err := auth.Hash([]byte(password), nil) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + user_id, err := db.CreateUser(username, hashSalt.Hash, hashSalt.Salt) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + session, err := createSession(user_id) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + w.Header().Add("Set-Cookie", session.ToCookie(false)) +} + +func Logout(w http.ResponseWriter, r *http.Request) { + session, err := validateSession(r) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + err = db.DeleteSession(session) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + + w.Header().Add("Set-Cookie", "session=;expires=Thu, 01 Jan 1970 00:00:00 UTC;samesite=strict;secure;HTTPonly") + w.WriteHeader(http.StatusOK) +} diff --git a/pages/root.go b/pages/root.go index 81594e0..594ccfd 100644 --- a/pages/root.go +++ b/pages/root.go @@ -7,5 +7,12 @@ import ( ) func RootPage(w http.ResponseWriter, r *http.Request) { - templates.RootPage().Render(r.Context(), w) + username, err := validateSessionAndGetUsername(r) + if err != nil { + w.Header().Add("Location", "/login") + w.WriteHeader(http.StatusFound) + return + } + + templates.RootPage(username).Render(r.Context(), w) } diff --git a/pages/templates/login.templ b/pages/templates/login.templ new file mode 100644 index 0000000..a0ea5ae --- /dev/null +++ b/pages/templates/login.templ @@ -0,0 +1,70 @@ +package templates + +templ loginSkeleton() { + + +
+