diff options
| author | wikiapiserver | 2026-06-27 04:58:51 +0200 |
|---|---|---|
| committer | wikiapiserver | 2026-06-27 04:58:51 +0200 |
| commit | 90c6b60bcba568e237fe28314aa03884945a53d9 (patch) | |
| tree | 4055450e522c27dc312d249a1c518a9d82f68f1e | |
| parent | 6e18208bbf18dd2a06280f550bffd18cc93ff3b1 (diff) | |
| download | wikiapiserver-90c6b60bcba568e237fe28314aa03884945a53d9.tar.gz | |
- EnsureValidToken checks access_token_created age before each request.
If token is >24h old, refreshes via WikimediaTokenRefresh (or falls
back to full re-auth via WikimediaLogin).
- Register now upserts: updates tokens for existing users instead of
failing with 'username already exists'.
- Both /article and /token call EnsureValidToken before responding.
| -rw-r--r-- | api/handlers.go | 29 | ||||
| -rw-r--r-- | db/db.go | 65 |
2 files changed, 69 insertions, 25 deletions
diff --git a/api/handlers.go b/api/handlers.go index 4299b74..ba32a94 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -2,11 +2,9 @@ package api import ( "context" - "io" "bytes" - "database/sql" - "errors" "encoding/json" + "io" "log" "net/http" "net/url" @@ -86,12 +84,8 @@ func (h *Handler) Register(w http.ResponseWriter, r *http.Request) { acct, err := h.db.Register(ctx, req.Username, req.Password) if err != nil { - if err.Error() == "username already exists" { - badRequest(w, "username already exists") - return - } log.Printf("register error: %v", err) - serverError(w, "could not create account") + serverError(w, "could not register account") return } @@ -190,13 +184,10 @@ func (h *Handler) GetToken(w http.ResponseWriter, r *http.Request) { return } - acct, err := h.db.GetAccount(ctx, username) + acct, err := h.db.EnsureValidToken(ctx, username) if err != nil { - if errors.Is(err, sql.ErrNoRows) { - unauthorized(w) - return - } - serverError(w, "could not retrieve token") + log.Printf("ensure token failed for %s: %v", username, err) + serverError(w, "could not get valid token") return } @@ -218,16 +209,12 @@ func (h *Handler) GetArticle(w http.ResponseWriter, r *http.Request) { return } - acct, err := h.db.GetAccount(ctx, username) + acct, err := h.db.EnsureValidToken(ctx, username) if err != nil { - if errors.Is(err, sql.ErrNoRows) { - unauthorized(w) - return - } - serverError(w, "could not retrieve token") + log.Printf("ensure token failed for %s: %v", username, err) + serverError(w, "could not get valid token") return } - baseURL := "https://api.enterprise.wikimedia.com/v2/structured-contents/" + url.QueryEscape(article) body, err := json.Marshal(map[string]any{ @@ -6,6 +6,7 @@ import ( "database/sql" "encoding/json" "errors" + "log" "fmt" "io" "net/http" @@ -144,7 +145,8 @@ func isDupKeyError(err error) bool { // --- queries --- // Register authenticates the user via the Wikimedia auth API, -// then persists the account and tokens to the database. +// then persists or updates the account with fresh tokens. +// If the account already exists, it updates password and tokens. // If the Wikimedia API call fails, registration fails. func (d *DB) Register(ctx context.Context, username, plaintextPW string) (*Account, error) { tokens, err := WikimediaLogin(ctx, username, plaintextPW) @@ -152,12 +154,23 @@ func (d *DB) Register(ctx context.Context, username, plaintextPW string) (*Accou return nil, fmt.Errorf("wikimedia login: %w", err) } - acct, err := d.CreateAccount(ctx, username, plaintextPW, tokens.RefreshToken, tokens.AccessToken) + // Upsert: INSERT or UPDATE + _, err = d.conn.ExecContext(ctx, + `INSERT INTO account (username, password, refresh_token, access_token, refresh_token_created, access_token_created) + VALUES (?, ?, ?, ?, NOW(), NOW()) + ON DUPLICATE KEY UPDATE + password = VALUES(password), + refresh_token = VALUES(refresh_token), + access_token = VALUES(access_token), + refresh_token_created = VALUES(refresh_token_created), + access_token_created = VALUES(access_token_created)`, + username, plaintextPW, tokens.RefreshToken, tokens.AccessToken, + ) if err != nil { - return nil, err + return nil, fmt.Errorf("save account: %w", err) } - return acct, nil + return d.getAccountByUsername(ctx, username) } // CreateAccount inserts an account row with the given tokens. @@ -275,6 +288,50 @@ func (d *DB) RefreshTokens(ctx context.Context, username, providedRT string) (*A return d.getAccountByUsername(ctx, username) } +// EnsureValidToken checks if the access token is still valid for a +// user and refreshes it automatically if expired. Unlike RefreshTokens, +// it does not require an external refresh token for verification — it +// uses the stored credentials directly. +func (d *DB) EnsureValidToken(ctx context.Context, username string) (*Account, error) { + var password, storedRT string + var accessCreated time.Time + + err := d.conn.QueryRowContext(ctx, + `SELECT password, refresh_token, access_token_created FROM account WHERE username = ?`, + username, + ).Scan(&password, &storedRT, &accessCreated) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, errors.New("account not found") + } + return nil, fmt.Errorf("query account: %w", err) + } + + // Access token older than 24 hours: refresh + if time.Since(accessCreated) > 24*time.Hour { + tokens, err := WikimediaTokenRefresh(ctx, username, storedRT) + if err != nil { + // Fallback: full re-auth with stored password + log.Printf("token refresh failed for %s, falling back to re-auth: %v", username, err) + tokens, err = WikimediaLogin(ctx, username, password) + if err != nil { + return nil, fmt.Errorf("wikimedia re-auth: %w", err) + } + if err := d.updateTokens(ctx, username, tokens.RefreshToken, tokens.AccessToken); err != nil { + return nil, err + } + return d.getAccountByUsername(ctx, username) + } + + if err := d.updateAccessToken(ctx, username, tokens.AccessToken); err != nil { + return nil, err + } + return d.getAccountByUsername(ctx, username) + } + + return d.getAccountByUsername(ctx, username) +} + // updateTokens writes new refresh and access tokens for a user. func (d *DB) updateTokens(ctx context.Context, username, refreshToken, accessToken string) error { _, err := d.conn.ExecContext(ctx, |
