summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore21
-rw-r--r--api/handlers.go166
-rw-r--r--db/db.go210
-rw-r--r--go.mod8
-rw-r--r--go.sum4
-rw-r--r--main.go119
6 files changed, 528 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..70367ca
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,21 @@
+# Build output
+*.exe
+*.exe~
+*.test
+*.out
+
+# Dependencies
+!vendor/
+
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+
+# OS
+.DS_Store
+Thumbs.db
+# Config (contains secrets)
+config.json
+database_password
diff --git a/api/handlers.go b/api/handlers.go
new file mode 100644
index 0000000..0d23bdb
--- /dev/null
+++ b/api/handlers.go
@@ -0,0 +1,166 @@
+package api
+
+import (
+ "context"
+ "encoding/json"
+ "net/http"
+ "time"
+
+ "wikiapiserver/db"
+)
+
+const defaultTimeout = 5 * time.Second
+
+// Handler holds the DB dependency for all HTTP handlers.
+type Handler struct {
+ db *db.DB
+}
+
+// NewHandler creates a Handler backed by the given DB.
+func NewHandler(database *db.DB) *Handler {
+ return &Handler{db: database}
+}
+
+// --- request/response types ---
+
+type errResp struct {
+ Error string `json:"error"`
+}
+
+type registerReq struct {
+ Username string `json:"username"`
+ Password string `json:"password"`
+}
+
+type loginReq struct {
+ Username string `json:"username"`
+ Password string `json:"password"`
+}
+
+type refreshReq struct {
+ RefreshToken string `json:"refresh_token"`
+}
+
+// --- helper writers ---
+
+func writeJSON(w http.ResponseWriter, code int, v any) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(code)
+ json.NewEncoder(w).Encode(v) //nolint:errcheck
+}
+
+func badRequest(w http.ResponseWriter, msg string) {
+ writeJSON(w, http.StatusBadRequest, errResp{Error: msg})
+}
+
+func unauthorized(w http.ResponseWriter) {
+ writeJSON(w, http.StatusUnauthorized, errResp{Error: "unauthorized"})
+}
+
+func serverError(w http.ResponseWriter, msg string) {
+ writeJSON(w, http.StatusInternalServerError, errResp{Error: msg})
+}
+
+// --- Register: POST /register ---
+
+func (h *Handler) Register(w http.ResponseWriter, r *http.Request) {
+ ctx, cancel := context.WithTimeout(r.Context(), defaultTimeout)
+ defer cancel()
+
+ var req registerReq
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ badRequest(w, "invalid JSON")
+ return
+ }
+
+ if req.Username == "" || req.Password == "" {
+ badRequest(w, "username and password are required")
+ return
+ }
+
+ acct, err := h.db.CreateAccount(ctx, req.Username, req.Password)
+ if err != nil {
+ if err.Error() == "username already exists" {
+ badRequest(w, "username already exists")
+ return
+ }
+ serverError(w, "could not create account")
+ return
+ }
+
+ writeJSON(w, http.StatusCreated, acct)
+}
+
+// --- Login: POST /login ---
+
+func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
+ ctx, cancel := context.WithTimeout(r.Context(), defaultTimeout)
+ defer cancel()
+
+ var req loginReq
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ badRequest(w, "invalid JSON")
+ return
+ }
+
+ if req.Username == "" || req.Password == "" {
+ badRequest(w, "username and password are required")
+ return
+ }
+
+ acct, err := h.db.Authenticate(ctx, req.Username, req.Password)
+ if err != nil {
+ if err.Error() == "invalid credentials" {
+ unauthorized(w)
+ return
+ }
+ serverError(w, "authentication failed")
+ return
+ }
+
+ writeJSON(w, http.StatusOK, acct)
+}
+
+// --- Refresh: POST /refresh ---
+
+func (h *Handler) Refresh(w http.ResponseWriter, r *http.Request) {
+ ctx, cancel := context.WithTimeout(r.Context(), defaultTimeout)
+ defer cancel()
+
+ var req refreshReq
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ badRequest(w, "invalid JSON")
+ return
+ }
+
+ if req.RefreshToken == "" {
+ badRequest(w, "refresh_token is required")
+ return
+ }
+
+ acct, err := h.db.RefreshByToken(ctx, req.RefreshToken)
+ if err != nil {
+ if err.Error() == "invalid refresh token" {
+ unauthorized(w)
+ return
+ }
+ serverError(w, "could not refresh token")
+ return
+ }
+
+ writeJSON(w, http.StatusOK, acct)
+}
+
+// --- Health: GET /health ---
+
+func (h *Handler) Health(w http.ResponseWriter, r *http.Request) {
+ ctx, cancel := context.WithTimeout(r.Context(), defaultTimeout)
+ defer cancel()
+
+ if err := h.db.Ping(ctx); err != nil {
+ writeJSON(w, http.StatusServiceUnavailable, errResp{Error: "database unavailable"})
+ return
+ }
+
+ writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
+}
diff --git a/db/db.go b/db/db.go
new file mode 100644
index 0000000..e011334
--- /dev/null
+++ b/db/db.go
@@ -0,0 +1,210 @@
+package db
+
+import (
+ "context"
+ "crypto/rand"
+ "crypto/sha256"
+ "database/sql"
+ "encoding/hex"
+ "errors"
+ "fmt"
+ "strings"
+ "time"
+
+ _ "github.com/go-sql-driver/mysql"
+)
+
+const (
+ tokenLength = 32 // bytes → 64 hex chars
+ accessTokenTTL = 24 * time.Hour
+)
+
+// Account represents a verified account with its current tokens.
+type Account struct {
+ ID int64 `json:"-"`
+ Username string `json:"username"`
+ RefreshToken string `json:"refresh_token"`
+ AccessToken string `json:"access_token"`
+ AccessTokenExpiry time.Time `json:"access_token_expires"`
+ CreatedAt time.Time `json:"created_at"`
+}
+
+// DB wraps sql.DB for wikiapiserver.
+type DB struct {
+ conn *sql.DB
+}
+
+// Connect opens a MariaDB connection, pings it, and returns a ready DB.
+func Connect(dsn string) (*DB, error) {
+ sqlDB, err := sql.Open("mysql", dsn)
+ if err != nil {
+ return nil, fmt.Errorf("open mysql: %w", err)
+ }
+ if err := sqlDB.Ping(); err != nil {
+ return nil, fmt.Errorf("ping mysql: %w", err)
+ }
+ sqlDB.SetMaxOpenConns(25)
+ sqlDB.SetMaxIdleConns(10)
+ sqlDB.SetConnMaxLifetime(5 * time.Minute)
+ return &DB{conn: sqlDB}, nil
+}
+
+// Close shuts down the connection pool.
+func (d *DB) Close() error {
+ return d.conn.Close()
+}
+
+// Ping checks database liveness.
+func (d *DB) Ping(ctx context.Context) error {
+ return d.conn.PingContext(ctx)
+}
+
+// --- crypto helpers ---
+
+func randomHex(n int) (string, error) {
+ b := make([]byte, n)
+ if _, err := rand.Read(b); err != nil {
+ return "", fmt.Errorf("crypto rand: %w", err)
+ }
+ return hex.EncodeToString(b), nil
+}
+
+func sha256hex(s string) string {
+ h := sha256.Sum256([]byte(s))
+ return hex.EncodeToString(h[:])
+}
+
+// --- error helpers ---
+
+func isDupKeyError(err error) bool {
+ return err != nil && strings.Contains(err.Error(), "1062")
+}
+
+// --- queries ---
+
+// CreateAccount inserts a new row with hashed credentials and fresh tokens.
+func (d *DB) CreateAccount(ctx context.Context, username, plaintextPW string) (*Account, error) {
+ rt, err := randomHex(tokenLength)
+ if err != nil {
+ return nil, err
+ }
+ at, err := randomHex(tokenLength)
+ if err != nil {
+ return nil, err
+ }
+
+ res, err := d.conn.ExecContext(ctx,
+ `INSERT INTO account (username, password, refresh_token, access_token, access_token_created)
+ VALUES (?, SHA2(?, 256), SHA2(?, 256), SHA2(?, 256), NOW())`,
+ username, plaintextPW, rt, at,
+ )
+ if err != nil {
+ if isDupKeyError(err) {
+ return nil, errors.New("username already exists")
+ }
+ return nil, fmt.Errorf("insert account: %w", err)
+ }
+
+ id, err := res.LastInsertId()
+ if err != nil {
+ return nil, err
+ }
+
+ now := time.Now()
+ return &Account{
+ ID: id,
+ Username: username,
+ RefreshToken: rt,
+ AccessToken: at,
+ AccessTokenExpiry: now.Add(accessTokenTTL),
+ CreatedAt: now,
+ }, nil
+}
+
+// Authenticate verifies plaintext credentials and returns fresh tokens.
+func (d *DB) Authenticate(ctx context.Context, username, plaintextPW string) (*Account, error) {
+ var storedHash string
+ err := d.conn.QueryRowContext(ctx,
+ `SELECT password FROM account WHERE username = ?`, username,
+ ).Scan(&storedHash)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, errors.New("invalid credentials")
+ }
+ return nil, fmt.Errorf("query user: %w", err)
+ }
+
+ if storedHash != sha256hex(plaintextPW) {
+ return nil, errors.New("invalid credentials")
+ }
+
+ // Look up account id, then rotate tokens
+ var id int64
+ if err := d.conn.QueryRowContext(ctx,
+ `SELECT id FROM account WHERE username = ?`, username,
+ ).Scan(&id); err != nil {
+ return nil, fmt.Errorf("lookup id: %w", err)
+ }
+
+ return d.RotateTokens(ctx, id)
+}
+
+// RotateTokens generates new refresh + access tokens for an account.
+func (d *DB) RotateTokens(ctx context.Context, id int64) (*Account, error) {
+ rt, err := randomHex(tokenLength)
+ if err != nil {
+ return nil, err
+ }
+ at, err := randomHex(tokenLength)
+ if err != nil {
+ return nil, err
+ }
+
+ res, err := d.conn.ExecContext(ctx,
+ `UPDATE account SET refresh_token = SHA2(?, 256), access_token = SHA2(?, 256), access_token_created = NOW()
+ WHERE id = ?`, rt, at, id)
+ if err != nil {
+ return nil, fmt.Errorf("rotate tokens: %w", err)
+ }
+ if affected, _ := res.RowsAffected(); affected == 0 {
+ return nil, errors.New("account not found")
+ }
+
+ var username string
+ if err := d.conn.QueryRowContext(ctx,
+ `SELECT username FROM account WHERE id = ?`, id,
+ ).Scan(&username); err != nil {
+ return nil, fmt.Errorf("lookup username: %w", err)
+ }
+
+ now := time.Now()
+ return &Account{
+ ID: id,
+ Username: username,
+ RefreshToken: rt,
+ AccessToken: at,
+ AccessTokenExpiry: now.Add(accessTokenTTL),
+ CreatedAt: now,
+ }, nil
+}
+
+// RefreshByToken looks up an account by its refresh token and rotates both tokens.
+func (d *DB) RefreshByToken(ctx context.Context, refreshToken string) (*Account, error) {
+ var id int64
+ err := d.conn.QueryRowContext(ctx,
+ `SELECT id FROM account WHERE refresh_token = SHA2(?, 256)`, refreshToken,
+ ).Scan(&id)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, errors.New("invalid refresh token")
+ }
+ return nil, fmt.Errorf("lookup refresh token: %w", err)
+ }
+ return d.RotateTokens(ctx, id)
+}
+
+// HealthCheck runs a trivial query to verify DB liveness.
+func (d *DB) HealthCheck(ctx context.Context) error {
+ _, err := d.conn.ExecContext(ctx, "SELECT 1")
+ return err
+}
diff --git a/go.mod b/go.mod
new file mode 100644
index 0000000..c096aa8
--- /dev/null
+++ b/go.mod
@@ -0,0 +1,8 @@
+module wikiapiserver
+
+go 1.26.4
+
+require (
+ filippo.io/edwards25519 v1.2.0 // indirect
+ github.com/go-sql-driver/mysql v1.10.0 // indirect
+)
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..1a983c8
--- /dev/null
+++ b/go.sum
@@ -0,0 +1,4 @@
+filippo.io/edwards25519 v1.2.0 h1:crnVqOiS4jqYleHd9vaKZ+HKtHfllngJIiOpNpoJsjo=
+filippo.io/edwards25519 v1.2.0/go.mod h1:xzAOLCNug/yB62zG1bQ8uziwrIqIuxhctzJT18Q77mc=
+github.com/go-sql-driver/mysql v1.10.0 h1:Q+1LV8DkHJvSYAdR83XzuhDaTykuDx0l6fkXxoWCWfw=
+github.com/go-sql-driver/mysql v1.10.0/go.mod h1:M+cqaI7+xxXGG9swrdeUIoPG3Y3KCkF0pZej+SK+nWk=
diff --git a/main.go b/main.go
new file mode 100644
index 0000000..813ac41
--- /dev/null
+++ b/main.go
@@ -0,0 +1,119 @@
+package main
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log"
+ "net/http"
+ "os"
+ "os/signal"
+ "path/filepath"
+ "syscall"
+ "time"
+
+ "wikiapiserver/api"
+ "wikiapiserver/db"
+)
+
+// Config is loaded from config.json at startup.
+type Config struct {
+ Database struct {
+ Host string `json:"host"`
+ Username string `json:"username"`
+ Password string `json:"password"`
+ Name string `json:"name"`
+ } `json:"database"`
+ Server struct {
+ Port int `json:"port"`
+ } `json:"server"`
+}
+
+func loadConfig() (*Config, error) {
+ // Try CWD first (works for go run), then executable dir (works for built binary)
+ for _, dir := range []string{".", ""} {
+ if dir == "" {
+ if exe, err := os.Executable(); err == nil {
+ dir = filepath.Dir(exe)
+ } else {
+ continue
+ }
+ }
+
+ path := filepath.Join(dir, "config.json")
+ f, err := os.Open(path)
+ if err != nil {
+ continue
+ }
+
+ var cfg Config
+ if err := json.NewDecoder(f).Decode(&cfg); err != nil {
+ f.Close()
+ return nil, fmt.Errorf("decode %s: %w", path, err)
+ }
+ f.Close()
+ return &cfg, nil
+ }
+
+ return nil, fmt.Errorf("config.json not found in . or next to executable")
+}
+
+func buildDSN(cfg *Config) string {
+ return fmt.Sprintf("%s:%s@tcp(%s)/%s",
+ cfg.Database.Username,
+ cfg.Database.Password,
+ cfg.Database.Host,
+ cfg.Database.Name,
+ )
+}
+
+func main() {
+ cfg, err := loadConfig()
+ if err != nil {
+ log.Fatalf("config: %v", err)
+ }
+
+ database, err := db.Connect(buildDSN(cfg))
+ if err != nil {
+ log.Fatalf("db: %v", err)
+ }
+ defer database.Close()
+
+ handler := api.NewHandler(database)
+
+ mux := http.NewServeMux()
+ mux.HandleFunc("POST /register", handler.Register)
+ mux.HandleFunc("POST /login", handler.Login)
+ mux.HandleFunc("POST /refresh", handler.Refresh)
+ mux.HandleFunc("GET /health", handler.Health)
+
+ addr := fmt.Sprintf(":%d", cfg.Server.Port)
+ srv := &http.Server{
+ Addr: addr,
+ Handler: mux,
+ ReadTimeout: 5 * time.Second,
+ WriteTimeout: 10 * time.Second,
+ IdleTimeout: 120 * time.Second,
+ }
+
+ // Graceful shutdown on SIGINT / SIGTERM
+ ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
+ defer stop()
+
+ go func() {
+ <-ctx.Done()
+ log.Println("shutting down...")
+ shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ if err := srv.Shutdown(shutdownCtx); err != nil {
+ log.Printf("shutdown error: %v", err)
+ }
+ }()
+
+ log.Printf("listening on %s", addr)
+ if err := srv.ListenAndServe(); err != http.ErrServerClosed {
+ log.Fatalf("server: %v", err)
+ }
+
+ log.Println("stopped")
+}