summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorwikiapiserver2026-06-25 14:47:35 +0200
committerwikiapiserver2026-06-25 14:47:35 +0200
commitcc960860e4109b4eb50721d0b3338df4b859d559 (patch)
tree666e75656092814461a5dc58fdc6b64c3677e390
parente375b6cc68f4a9b0e91b25479538dc76d1f1e620 (diff)
downloadwikiapiserver-cc960860e4109b4eb50721d0b3338df4b859d559.tar.gz
feat: token refresh with age-based logic
- RefreshTokens checks token age and chooses the right path: - refresh_token > 90 days: re-auth via WikimediaLogin (full login) - access_token > 24 hours: refresh via WikimediaTokenRefresh - otherwise: return current tokens - WikimediaTokenRefresh posts to /v1/token-refresh endpoint - Login also uses WikimediaLogin instead of local RotateTokens - Removed dead RotateTokens, RefreshByToken, and randomHex - DSN includes parseTime=true for timestamp columns
-rw-r--r--api/handlers.go20
-rw-r--r--db/db.go191
-rw-r--r--main.go2
3 files changed, 141 insertions, 72 deletions
diff --git a/api/handlers.go b/api/handlers.go
index 7918b40..f98dd6b 100644
--- a/api/handlers.go
+++ b/api/handlers.go
@@ -38,10 +38,6 @@ type loginReq struct {
Password string `json:"password"`
}
-type refreshReq struct {
- RefreshToken string `json:"refresh_token"`
-}
-
// --- helper writers ---
func writeJSON(w http.ResponseWriter, code int, v any) {
@@ -124,6 +120,12 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) {
}
// --- Refresh: POST /refresh ---
+// Accepts username and refresh_token. The refresh_token is used to
+// verify identity; RefreshTokens handles the age-based logic.
+type refreshReq struct {
+ Username string `json:"username"`
+ RefreshToken string `json:"refresh_token"`
+}
func (h *Handler) Refresh(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), defaultTimeout)
@@ -135,17 +137,17 @@ func (h *Handler) Refresh(w http.ResponseWriter, r *http.Request) {
return
}
- if req.RefreshToken == "" {
- badRequest(w, "refresh_token is required")
+ if req.Username == "" || req.RefreshToken == "" {
+ badRequest(w, "username and refresh_token are required")
return
}
-
- acct, err := h.db.RefreshByToken(ctx, req.RefreshToken)
+ acct, err := h.db.RefreshTokens(ctx, req.Username, req.RefreshToken)
if err != nil {
- if err.Error() == "invalid refresh token" {
+ if err.Error() == "account not found" || err.Error() == "invalid refresh token" {
unauthorized(w)
return
}
+ log.Printf("refresh error: %v", err)
serverError(w, "could not refresh token")
return
}
diff --git a/db/db.go b/db/db.go
index ce0dbfd..6fed1b0 100644
--- a/db/db.go
+++ b/db/db.go
@@ -2,9 +2,8 @@ package db
import (
"context"
- "crypto/rand"
+ "bytes"
"database/sql"
- "encoding/hex"
"encoding/json"
"errors"
"fmt"
@@ -17,10 +16,7 @@ import (
_ "github.com/go-sql-driver/mysql"
)
-const (
- tokenLength = 32
- accessTokenTTL = 24 * time.Hour
-)
+const accessTokenTTL = 24 * time.Hour
// Account represents a verified account with its current tokens.
type Account struct {
@@ -62,16 +58,6 @@ func (d *DB) Ping(ctx context.Context) error {
return d.conn.PingContext(ctx)
}
-// --- crypto helpers ---
-
-func randomHex(n int) (string, error) {
- b := make([]byte, n)
- if _, err := rand.Read(b); err != nil {
- return "", fmt.Errorf("crypto rand: %w", err)
- }
- return hex.EncodeToString(b), nil
-}
-
// WikimediaTokens holds the tokens returned by the Wikimedia auth API.
type WikimediaTokens struct {
RefreshToken string `json:"refresh_token"`
@@ -111,6 +97,44 @@ func WikimediaLogin(ctx context.Context, username, password string) (*WikimediaT
return &tokens, nil
}
+// WikimediaTokenRefresh posts to the Wikimedia token-refresh endpoint
+// and returns new tokens from the response.
+func WikimediaTokenRefresh(ctx context.Context, username, refreshToken string) (*WikimediaTokens, error) {
+ body, err := json.Marshal(map[string]string{
+ "username": username,
+ "refresh_token": refreshToken,
+ })
+ if err != nil {
+ return nil, fmt.Errorf("marshal token refresh body: %w", err)
+ }
+
+ req, err := http.NewRequestWithContext(ctx, "POST",
+ "https://auth.enterprise.wikimedia.com/v1/token-refresh",
+ bytes.NewReader(body))
+ if err != nil {
+ return nil, fmt.Errorf("new request: %w", err)
+ }
+ req.Header.Set("Content-Type", "application/json")
+
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ return nil, fmt.Errorf("token refresh request: %w", err)
+ }
+ defer resp.Body.Close()
+
+ if resp.StatusCode < 200 || resp.StatusCode >= 300 {
+ bs, _ := io.ReadAll(resp.Body)
+ return nil, fmt.Errorf("token refresh failed (status %d): %s", resp.StatusCode, string(bs))
+ }
+
+ var tokens WikimediaTokens
+ if err := json.NewDecoder(resp.Body).Decode(&tokens); err != nil {
+ return nil, fmt.Errorf("decode token refresh response: %w", err)
+ }
+
+ return &tokens, nil
+}
+
// --- error helpers ---
func isDupKeyError(err error) bool {
@@ -166,7 +190,8 @@ func (d *DB) CreateAccount(ctx context.Context, username, plaintextPW, refreshTo
}, nil
}
-// Authenticate verifies plaintext credentials and returns fresh tokens.
+// Authenticate verifies plaintext credentials and returns an account
+// with fresh tokens obtained from the Wikimedia auth API.
func (d *DB) Authenticate(ctx context.Context, username, plaintextPW string) (*Account, error) {
var storedPW string
err := d.conn.QueryRowContext(ctx,
@@ -183,69 +208,111 @@ func (d *DB) Authenticate(ctx context.Context, username, plaintextPW string) (*A
return nil, errors.New("invalid credentials")
}
- // Look up account id, then rotate tokens
- var id int64
- if err := d.conn.QueryRowContext(ctx,
- `SELECT id FROM account WHERE username = ?`, username,
- ).Scan(&id); err != nil {
- return nil, fmt.Errorf("lookup id: %w", err)
+ tokens, err := WikimediaLogin(ctx, username, plaintextPW)
+ if err != nil {
+ return nil, fmt.Errorf("wikimedia login: %w", err)
}
- return d.RotateTokens(ctx, id)
+ if err := d.updateTokens(ctx, username, tokens.RefreshToken, tokens.AccessToken); err != nil {
+ return nil, err
+ }
+
+ return d.getAccountByUsername(ctx, username)
}
-// RotateTokens generates new refresh + access tokens for an account.
-func (d *DB) RotateTokens(ctx context.Context, id int64) (*Account, error) {
- rt, err := randomHex(tokenLength)
+// RefreshTokens checks the age of the stored tokens and refreshes
+// them via the Wikimedia auth API as needed:
+// - refresh_token older than 90 days: full re-auth via WikimediaLogin
+// - access_token older than 24 hours: token refresh via WikimediaTokenRefresh
+func (d *DB) RefreshTokens(ctx context.Context, username, providedRT string) (*Account, error) {
+ var id int64
+ var password, storedRT, accessToken string
+ var refreshCreated, accessCreated time.Time
+
+ err := d.conn.QueryRowContext(ctx,
+ `SELECT id, password, refresh_token, access_token, refresh_token_created, access_token_created
+ FROM account WHERE username = ?`, username,
+ ).Scan(&id, &password, &storedRT, &accessToken, &refreshCreated, &accessCreated)
if err != nil {
- return nil, err
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, errors.New("account not found")
+ }
+ return nil, fmt.Errorf("query account: %w", err)
}
- at, err := randomHex(tokenLength)
- if err != nil {
- return nil, err
+
+ // Verify the provided refresh token matches
+ if storedRT != providedRT {
+ return nil, errors.New("invalid refresh token")
}
- res, err := d.conn.ExecContext(ctx,
- `UPDATE account SET refresh_token = ?, access_token = ?, refresh_token_created = NOW(), access_token_created = NOW()
- WHERE id = ?`, rt, at, id)
- if err != nil {
- return nil, fmt.Errorf("rotate tokens: %w", err)
+ // Refresh token older than 90 days: re-authenticate
+ if time.Since(refreshCreated) > 90*24*time.Hour {
+ tokens, err := WikimediaLogin(ctx, username, password)
+ if err != nil {
+ return nil, fmt.Errorf("wikimedia re-auth: %w", err)
+ }
+ if err := d.updateTokens(ctx, username, tokens.RefreshToken, tokens.AccessToken); err != nil {
+ return nil, err
+ }
+ return d.getAccountByUsername(ctx, username)
}
- if affected, _ := res.RowsAffected(); affected == 0 {
- return nil, errors.New("account not found")
+
+ // Access token older than 24 hours: refresh via token-refresh endpoint
+ if time.Since(accessCreated) > 24*time.Hour {
+ tokens, err := WikimediaTokenRefresh(ctx, username, storedRT)
+ if err != nil {
+ return nil, fmt.Errorf("wikimedia token refresh: %w", err)
+ }
+
+ // Update both tokens (refresh response may include a new refresh token)
+ if err := d.updateAccessToken(ctx, username, tokens.AccessToken); err != nil {
+ return nil, err
+ }
+ return d.getAccountByUsername(ctx, username)
}
- var username string
- if err := d.conn.QueryRowContext(ctx,
- `SELECT username FROM account WHERE id = ?`, id,
- ).Scan(&username); err != nil {
- return nil, fmt.Errorf("lookup username: %w", err)
+ // Tokens are still valid
+ return d.getAccountByUsername(ctx, username)
+}
+
+// updateTokens writes new refresh and access tokens for a user.
+func (d *DB) updateTokens(ctx context.Context, username, refreshToken, accessToken string) error {
+ _, err := d.conn.ExecContext(ctx,
+ `UPDATE account SET refresh_token = ?, access_token = ?, refresh_token_created = NOW(), access_token_created = NOW()
+ WHERE username = ?`, refreshToken, accessToken, username)
+ if err != nil {
+ return fmt.Errorf("update tokens: %w", err)
}
+ return nil
+}
- now := time.Now()
- return &Account{
- ID: id,
- Username: username,
- RefreshToken: rt,
- AccessToken: at,
- AccessTokenExpiry: now.Add(accessTokenTTL),
- CreatedAt: now,
- }, nil
+// updateAccessToken writes a new access token for a user.
+func (d *DB) updateAccessToken(ctx context.Context, username, accessToken string) error {
+ _, err := d.conn.ExecContext(ctx,
+ `UPDATE account SET access_token = ?, access_token_created = NOW()
+ WHERE username = ?`, accessToken, username)
+ if err != nil {
+ return fmt.Errorf("update access token: %w", err)
+ }
+ return nil
}
-// RefreshByToken looks up an account by its refresh token and rotates both tokens.
-func (d *DB) RefreshByToken(ctx context.Context, refreshToken string) (*Account, error) {
- var id int64
+// getAccountByUsername fetches the current account state.
+func (d *DB) getAccountByUsername(ctx context.Context, username string) (*Account, error) {
+ var acct Account
+ var refreshCreated, accessCreated time.Time
+
err := d.conn.QueryRowContext(ctx,
- `SELECT id FROM account WHERE refresh_token = ?`, refreshToken,
- ).Scan(&id)
+ `SELECT id, username, refresh_token, access_token, refresh_token_created, access_token_created, created_at
+ FROM account WHERE username = ?`, username,
+ ).Scan(&acct.ID, &acct.Username, &acct.RefreshToken, &acct.AccessToken,
+ &refreshCreated, &accessCreated, &acct.CreatedAt)
if err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- return nil, errors.New("invalid refresh token")
- }
- return nil, fmt.Errorf("lookup refresh token: %w", err)
+ return nil, fmt.Errorf("get account: %w", err)
}
- return d.RotateTokens(ctx, id)
+
+ acct.AccessTokenExpiry = accessCreated
+ return &acct, nil
}
// HealthCheck runs a trivial query to verify DB liveness.
diff --git a/main.go b/main.go
index 813ac41..6662d90 100644
--- a/main.go
+++ b/main.go
@@ -59,7 +59,7 @@ func loadConfig() (*Config, error) {
}
func buildDSN(cfg *Config) string {
- return fmt.Sprintf("%s:%s@tcp(%s)/%s",
+ return fmt.Sprintf("%s:%s@tcp(%s)/%s?parseTime=true",
cfg.Database.Username,
cfg.Database.Password,
cfg.Database.Host,