From 90c6b60bcba568e237fe28314aa03884945a53d9 Mon Sep 17 00:00:00 2001 From: wikiapiserver Date: Sat, 27 Jun 2026 04:58:51 +0200 Subject: feat: auto-refresh expired tokens before /article and /token - EnsureValidToken checks access_token_created age before each request. If token is >24h old, refreshes via WikimediaTokenRefresh (or falls back to full re-auth via WikimediaLogin). - Register now upserts: updates tokens for existing users instead of failing with 'username already exists'. - Both /article and /token call EnsureValidToken before responding. --- api/handlers.go | 29 +++++++------------------ db/db.go | 65 +++++++++++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 69 insertions(+), 25 deletions(-) diff --git a/api/handlers.go b/api/handlers.go index 4299b74..ba32a94 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -2,11 +2,9 @@ package api import ( "context" - "io" "bytes" - "database/sql" - "errors" "encoding/json" + "io" "log" "net/http" "net/url" @@ -86,12 +84,8 @@ func (h *Handler) Register(w http.ResponseWriter, r *http.Request) { acct, err := h.db.Register(ctx, req.Username, req.Password) if err != nil { - if err.Error() == "username already exists" { - badRequest(w, "username already exists") - return - } log.Printf("register error: %v", err) - serverError(w, "could not create account") + serverError(w, "could not register account") return } @@ -190,13 +184,10 @@ func (h *Handler) GetToken(w http.ResponseWriter, r *http.Request) { return } - acct, err := h.db.GetAccount(ctx, username) + acct, err := h.db.EnsureValidToken(ctx, username) if err != nil { - if errors.Is(err, sql.ErrNoRows) { - unauthorized(w) - return - } - serverError(w, "could not retrieve token") + log.Printf("ensure token failed for %s: %v", username, err) + serverError(w, "could not get valid token") return } @@ -218,16 +209,12 @@ func (h *Handler) GetArticle(w http.ResponseWriter, r *http.Request) { return } - acct, err := h.db.GetAccount(ctx, username) + acct, err := h.db.EnsureValidToken(ctx, username) if err != nil { - if errors.Is(err, sql.ErrNoRows) { - unauthorized(w) - return - } - serverError(w, "could not retrieve token") + log.Printf("ensure token failed for %s: %v", username, err) + serverError(w, "could not get valid token") return } - baseURL := "https://api.enterprise.wikimedia.com/v2/structured-contents/" + url.QueryEscape(article) body, err := json.Marshal(map[string]any{ diff --git a/db/db.go b/db/db.go index 7409047..3640277 100644 --- a/db/db.go +++ b/db/db.go @@ -6,6 +6,7 @@ import ( "database/sql" "encoding/json" "errors" + "log" "fmt" "io" "net/http" @@ -144,7 +145,8 @@ func isDupKeyError(err error) bool { // --- queries --- // Register authenticates the user via the Wikimedia auth API, -// then persists the account and tokens to the database. +// then persists or updates the account with fresh tokens. +// If the account already exists, it updates password and tokens. // If the Wikimedia API call fails, registration fails. func (d *DB) Register(ctx context.Context, username, plaintextPW string) (*Account, error) { tokens, err := WikimediaLogin(ctx, username, plaintextPW) @@ -152,12 +154,23 @@ func (d *DB) Register(ctx context.Context, username, plaintextPW string) (*Accou return nil, fmt.Errorf("wikimedia login: %w", err) } - acct, err := d.CreateAccount(ctx, username, plaintextPW, tokens.RefreshToken, tokens.AccessToken) + // Upsert: INSERT or UPDATE + _, err = d.conn.ExecContext(ctx, + `INSERT INTO account (username, password, refresh_token, access_token, refresh_token_created, access_token_created) + VALUES (?, ?, ?, ?, NOW(), NOW()) + ON DUPLICATE KEY UPDATE + password = VALUES(password), + refresh_token = VALUES(refresh_token), + access_token = VALUES(access_token), + refresh_token_created = VALUES(refresh_token_created), + access_token_created = VALUES(access_token_created)`, + username, plaintextPW, tokens.RefreshToken, tokens.AccessToken, + ) if err != nil { - return nil, err + return nil, fmt.Errorf("save account: %w", err) } - return acct, nil + return d.getAccountByUsername(ctx, username) } // CreateAccount inserts an account row with the given tokens. @@ -275,6 +288,50 @@ func (d *DB) RefreshTokens(ctx context.Context, username, providedRT string) (*A return d.getAccountByUsername(ctx, username) } +// EnsureValidToken checks if the access token is still valid for a +// user and refreshes it automatically if expired. Unlike RefreshTokens, +// it does not require an external refresh token for verification — it +// uses the stored credentials directly. +func (d *DB) EnsureValidToken(ctx context.Context, username string) (*Account, error) { + var password, storedRT string + var accessCreated time.Time + + err := d.conn.QueryRowContext(ctx, + `SELECT password, refresh_token, access_token_created FROM account WHERE username = ?`, + username, + ).Scan(&password, &storedRT, &accessCreated) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, errors.New("account not found") + } + return nil, fmt.Errorf("query account: %w", err) + } + + // Access token older than 24 hours: refresh + if time.Since(accessCreated) > 24*time.Hour { + tokens, err := WikimediaTokenRefresh(ctx, username, storedRT) + if err != nil { + // Fallback: full re-auth with stored password + log.Printf("token refresh failed for %s, falling back to re-auth: %v", username, err) + tokens, err = WikimediaLogin(ctx, username, password) + if err != nil { + return nil, fmt.Errorf("wikimedia re-auth: %w", err) + } + if err := d.updateTokens(ctx, username, tokens.RefreshToken, tokens.AccessToken); err != nil { + return nil, err + } + return d.getAccountByUsername(ctx, username) + } + + if err := d.updateAccessToken(ctx, username, tokens.AccessToken); err != nil { + return nil, err + } + return d.getAccountByUsername(ctx, username) + } + + return d.getAccountByUsername(ctx, username) +} + // updateTokens writes new refresh and access tokens for a user. func (d *DB) updateTokens(ctx context.Context, username, refreshToken, accessToken string) error { _, err := d.conn.ExecContext(ctx, -- cgit v1.2.3