diff options
| author | wikiapiserver | 2026-06-25 14:47:35 +0200 |
|---|---|---|
| committer | wikiapiserver | 2026-06-25 14:47:35 +0200 |
| commit | cc960860e4109b4eb50721d0b3338df4b859d559 (patch) | |
| tree | 666e75656092814461a5dc58fdc6b64c3677e390 | |
| parent | e375b6cc68f4a9b0e91b25479538dc76d1f1e620 (diff) | |
| download | wikiapiserver-cc960860e4109b4eb50721d0b3338df4b859d559.tar.gz | |
feat: token refresh with age-based logic
- RefreshTokens checks token age and chooses the right path:
- refresh_token > 90 days: re-auth via WikimediaLogin (full login)
- access_token > 24 hours: refresh via WikimediaTokenRefresh
- otherwise: return current tokens
- WikimediaTokenRefresh posts to /v1/token-refresh endpoint
- Login also uses WikimediaLogin instead of local RotateTokens
- Removed dead RotateTokens, RefreshByToken, and randomHex
- DSN includes parseTime=true for timestamp columns
| -rw-r--r-- | api/handlers.go | 20 | ||||
| -rw-r--r-- | db/db.go | 191 | ||||
| -rw-r--r-- | main.go | 2 |
3 files changed, 141 insertions, 72 deletions
diff --git a/api/handlers.go b/api/handlers.go index 7918b40..f98dd6b 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -38,10 +38,6 @@ type loginReq struct { Password string `json:"password"` } -type refreshReq struct { - RefreshToken string `json:"refresh_token"` -} - // --- helper writers --- func writeJSON(w http.ResponseWriter, code int, v any) { @@ -124,6 +120,12 @@ func (h *Handler) Login(w http.ResponseWriter, r *http.Request) { } // --- Refresh: POST /refresh --- +// Accepts username and refresh_token. The refresh_token is used to +// verify identity; RefreshTokens handles the age-based logic. +type refreshReq struct { + Username string `json:"username"` + RefreshToken string `json:"refresh_token"` +} func (h *Handler) Refresh(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), defaultTimeout) @@ -135,17 +137,17 @@ func (h *Handler) Refresh(w http.ResponseWriter, r *http.Request) { return } - if req.RefreshToken == "" { - badRequest(w, "refresh_token is required") + if req.Username == "" || req.RefreshToken == "" { + badRequest(w, "username and refresh_token are required") return } - - acct, err := h.db.RefreshByToken(ctx, req.RefreshToken) + acct, err := h.db.RefreshTokens(ctx, req.Username, req.RefreshToken) if err != nil { - if err.Error() == "invalid refresh token" { + if err.Error() == "account not found" || err.Error() == "invalid refresh token" { unauthorized(w) return } + log.Printf("refresh error: %v", err) serverError(w, "could not refresh token") return } @@ -2,9 +2,8 @@ package db import ( "context" - "crypto/rand" + "bytes" "database/sql" - "encoding/hex" "encoding/json" "errors" "fmt" @@ -17,10 +16,7 @@ import ( _ "github.com/go-sql-driver/mysql" ) -const ( - tokenLength = 32 - accessTokenTTL = 24 * time.Hour -) +const accessTokenTTL = 24 * time.Hour // Account represents a verified account with its current tokens. type Account struct { @@ -62,16 +58,6 @@ func (d *DB) Ping(ctx context.Context) error { return d.conn.PingContext(ctx) } -// --- crypto helpers --- - -func randomHex(n int) (string, error) { - b := make([]byte, n) - if _, err := rand.Read(b); err != nil { - return "", fmt.Errorf("crypto rand: %w", err) - } - return hex.EncodeToString(b), nil -} - // WikimediaTokens holds the tokens returned by the Wikimedia auth API. type WikimediaTokens struct { RefreshToken string `json:"refresh_token"` @@ -111,6 +97,44 @@ func WikimediaLogin(ctx context.Context, username, password string) (*WikimediaT return &tokens, nil } +// WikimediaTokenRefresh posts to the Wikimedia token-refresh endpoint +// and returns new tokens from the response. +func WikimediaTokenRefresh(ctx context.Context, username, refreshToken string) (*WikimediaTokens, error) { + body, err := json.Marshal(map[string]string{ + "username": username, + "refresh_token": refreshToken, + }) + if err != nil { + return nil, fmt.Errorf("marshal token refresh body: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", + "https://auth.enterprise.wikimedia.com/v1/token-refresh", + bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("new request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return nil, fmt.Errorf("token refresh request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + bs, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("token refresh failed (status %d): %s", resp.StatusCode, string(bs)) + } + + var tokens WikimediaTokens + if err := json.NewDecoder(resp.Body).Decode(&tokens); err != nil { + return nil, fmt.Errorf("decode token refresh response: %w", err) + } + + return &tokens, nil +} + // --- error helpers --- func isDupKeyError(err error) bool { @@ -166,7 +190,8 @@ func (d *DB) CreateAccount(ctx context.Context, username, plaintextPW, refreshTo }, nil } -// Authenticate verifies plaintext credentials and returns fresh tokens. +// Authenticate verifies plaintext credentials and returns an account +// with fresh tokens obtained from the Wikimedia auth API. func (d *DB) Authenticate(ctx context.Context, username, plaintextPW string) (*Account, error) { var storedPW string err := d.conn.QueryRowContext(ctx, @@ -183,69 +208,111 @@ func (d *DB) Authenticate(ctx context.Context, username, plaintextPW string) (*A return nil, errors.New("invalid credentials") } - // Look up account id, then rotate tokens - var id int64 - if err := d.conn.QueryRowContext(ctx, - `SELECT id FROM account WHERE username = ?`, username, - ).Scan(&id); err != nil { - return nil, fmt.Errorf("lookup id: %w", err) + tokens, err := WikimediaLogin(ctx, username, plaintextPW) + if err != nil { + return nil, fmt.Errorf("wikimedia login: %w", err) } - return d.RotateTokens(ctx, id) + if err := d.updateTokens(ctx, username, tokens.RefreshToken, tokens.AccessToken); err != nil { + return nil, err + } + + return d.getAccountByUsername(ctx, username) } -// RotateTokens generates new refresh + access tokens for an account. -func (d *DB) RotateTokens(ctx context.Context, id int64) (*Account, error) { - rt, err := randomHex(tokenLength) +// RefreshTokens checks the age of the stored tokens and refreshes +// them via the Wikimedia auth API as needed: +// - refresh_token older than 90 days: full re-auth via WikimediaLogin +// - access_token older than 24 hours: token refresh via WikimediaTokenRefresh +func (d *DB) RefreshTokens(ctx context.Context, username, providedRT string) (*Account, error) { + var id int64 + var password, storedRT, accessToken string + var refreshCreated, accessCreated time.Time + + err := d.conn.QueryRowContext(ctx, + `SELECT id, password, refresh_token, access_token, refresh_token_created, access_token_created + FROM account WHERE username = ?`, username, + ).Scan(&id, &password, &storedRT, &accessToken, &refreshCreated, &accessCreated) if err != nil { - return nil, err + if errors.Is(err, sql.ErrNoRows) { + return nil, errors.New("account not found") + } + return nil, fmt.Errorf("query account: %w", err) } - at, err := randomHex(tokenLength) - if err != nil { - return nil, err + + // Verify the provided refresh token matches + if storedRT != providedRT { + return nil, errors.New("invalid refresh token") } - res, err := d.conn.ExecContext(ctx, - `UPDATE account SET refresh_token = ?, access_token = ?, refresh_token_created = NOW(), access_token_created = NOW() - WHERE id = ?`, rt, at, id) - if err != nil { - return nil, fmt.Errorf("rotate tokens: %w", err) + // Refresh token older than 90 days: re-authenticate + if time.Since(refreshCreated) > 90*24*time.Hour { + tokens, err := WikimediaLogin(ctx, username, password) + if err != nil { + return nil, fmt.Errorf("wikimedia re-auth: %w", err) + } + if err := d.updateTokens(ctx, username, tokens.RefreshToken, tokens.AccessToken); err != nil { + return nil, err + } + return d.getAccountByUsername(ctx, username) } - if affected, _ := res.RowsAffected(); affected == 0 { - return nil, errors.New("account not found") + + // Access token older than 24 hours: refresh via token-refresh endpoint + if time.Since(accessCreated) > 24*time.Hour { + tokens, err := WikimediaTokenRefresh(ctx, username, storedRT) + if err != nil { + return nil, fmt.Errorf("wikimedia token refresh: %w", err) + } + + // Update both tokens (refresh response may include a new refresh token) + if err := d.updateAccessToken(ctx, username, tokens.AccessToken); err != nil { + return nil, err + } + return d.getAccountByUsername(ctx, username) } - var username string - if err := d.conn.QueryRowContext(ctx, - `SELECT username FROM account WHERE id = ?`, id, - ).Scan(&username); err != nil { - return nil, fmt.Errorf("lookup username: %w", err) + // Tokens are still valid + return d.getAccountByUsername(ctx, username) +} + +// updateTokens writes new refresh and access tokens for a user. +func (d *DB) updateTokens(ctx context.Context, username, refreshToken, accessToken string) error { + _, err := d.conn.ExecContext(ctx, + `UPDATE account SET refresh_token = ?, access_token = ?, refresh_token_created = NOW(), access_token_created = NOW() + WHERE username = ?`, refreshToken, accessToken, username) + if err != nil { + return fmt.Errorf("update tokens: %w", err) } + return nil +} - now := time.Now() - return &Account{ - ID: id, - Username: username, - RefreshToken: rt, - AccessToken: at, - AccessTokenExpiry: now.Add(accessTokenTTL), - CreatedAt: now, - }, nil +// updateAccessToken writes a new access token for a user. +func (d *DB) updateAccessToken(ctx context.Context, username, accessToken string) error { + _, err := d.conn.ExecContext(ctx, + `UPDATE account SET access_token = ?, access_token_created = NOW() + WHERE username = ?`, accessToken, username) + if err != nil { + return fmt.Errorf("update access token: %w", err) + } + return nil } -// RefreshByToken looks up an account by its refresh token and rotates both tokens. -func (d *DB) RefreshByToken(ctx context.Context, refreshToken string) (*Account, error) { - var id int64 +// getAccountByUsername fetches the current account state. +func (d *DB) getAccountByUsername(ctx context.Context, username string) (*Account, error) { + var acct Account + var refreshCreated, accessCreated time.Time + err := d.conn.QueryRowContext(ctx, - `SELECT id FROM account WHERE refresh_token = ?`, refreshToken, - ).Scan(&id) + `SELECT id, username, refresh_token, access_token, refresh_token_created, access_token_created, created_at + FROM account WHERE username = ?`, username, + ).Scan(&acct.ID, &acct.Username, &acct.RefreshToken, &acct.AccessToken, + &refreshCreated, &accessCreated, &acct.CreatedAt) if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, errors.New("invalid refresh token") - } - return nil, fmt.Errorf("lookup refresh token: %w", err) + return nil, fmt.Errorf("get account: %w", err) } - return d.RotateTokens(ctx, id) + + acct.AccessTokenExpiry = accessCreated + return &acct, nil } // HealthCheck runs a trivial query to verify DB liveness. @@ -59,7 +59,7 @@ func loadConfig() (*Config, error) { } func buildDSN(cfg *Config) string { - return fmt.Sprintf("%s:%s@tcp(%s)/%s", + return fmt.Sprintf("%s:%s@tcp(%s)/%s?parseTime=true", cfg.Database.Username, cfg.Database.Password, cfg.Database.Host, |
