summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--api/handlers.go29
-rw-r--r--db/db.go65
2 files changed, 69 insertions, 25 deletions
diff --git a/api/handlers.go b/api/handlers.go
index 4299b74..ba32a94 100644
--- a/api/handlers.go
+++ b/api/handlers.go
@@ -2,11 +2,9 @@ package api
import (
"context"
- "io"
"bytes"
- "database/sql"
- "errors"
"encoding/json"
+ "io"
"log"
"net/http"
"net/url"
@@ -86,12 +84,8 @@ func (h *Handler) Register(w http.ResponseWriter, r *http.Request) {
acct, err := h.db.Register(ctx, req.Username, req.Password)
if err != nil {
- if err.Error() == "username already exists" {
- badRequest(w, "username already exists")
- return
- }
log.Printf("register error: %v", err)
- serverError(w, "could not create account")
+ serverError(w, "could not register account")
return
}
@@ -190,13 +184,10 @@ func (h *Handler) GetToken(w http.ResponseWriter, r *http.Request) {
return
}
- acct, err := h.db.GetAccount(ctx, username)
+ acct, err := h.db.EnsureValidToken(ctx, username)
if err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- unauthorized(w)
- return
- }
- serverError(w, "could not retrieve token")
+ log.Printf("ensure token failed for %s: %v", username, err)
+ serverError(w, "could not get valid token")
return
}
@@ -218,16 +209,12 @@ func (h *Handler) GetArticle(w http.ResponseWriter, r *http.Request) {
return
}
- acct, err := h.db.GetAccount(ctx, username)
+ acct, err := h.db.EnsureValidToken(ctx, username)
if err != nil {
- if errors.Is(err, sql.ErrNoRows) {
- unauthorized(w)
- return
- }
- serverError(w, "could not retrieve token")
+ log.Printf("ensure token failed for %s: %v", username, err)
+ serverError(w, "could not get valid token")
return
}
-
baseURL := "https://api.enterprise.wikimedia.com/v2/structured-contents/" + url.QueryEscape(article)
body, err := json.Marshal(map[string]any{
diff --git a/db/db.go b/db/db.go
index 7409047..3640277 100644
--- a/db/db.go
+++ b/db/db.go
@@ -6,6 +6,7 @@ import (
"database/sql"
"encoding/json"
"errors"
+ "log"
"fmt"
"io"
"net/http"
@@ -144,7 +145,8 @@ func isDupKeyError(err error) bool {
// --- queries ---
// Register authenticates the user via the Wikimedia auth API,
-// then persists the account and tokens to the database.
+// then persists or updates the account with fresh tokens.
+// If the account already exists, it updates password and tokens.
// If the Wikimedia API call fails, registration fails.
func (d *DB) Register(ctx context.Context, username, plaintextPW string) (*Account, error) {
tokens, err := WikimediaLogin(ctx, username, plaintextPW)
@@ -152,12 +154,23 @@ func (d *DB) Register(ctx context.Context, username, plaintextPW string) (*Accou
return nil, fmt.Errorf("wikimedia login: %w", err)
}
- acct, err := d.CreateAccount(ctx, username, plaintextPW, tokens.RefreshToken, tokens.AccessToken)
+ // Upsert: INSERT or UPDATE
+ _, err = d.conn.ExecContext(ctx,
+ `INSERT INTO account (username, password, refresh_token, access_token, refresh_token_created, access_token_created)
+ VALUES (?, ?, ?, ?, NOW(), NOW())
+ ON DUPLICATE KEY UPDATE
+ password = VALUES(password),
+ refresh_token = VALUES(refresh_token),
+ access_token = VALUES(access_token),
+ refresh_token_created = VALUES(refresh_token_created),
+ access_token_created = VALUES(access_token_created)`,
+ username, plaintextPW, tokens.RefreshToken, tokens.AccessToken,
+ )
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("save account: %w", err)
}
- return acct, nil
+ return d.getAccountByUsername(ctx, username)
}
// CreateAccount inserts an account row with the given tokens.
@@ -275,6 +288,50 @@ func (d *DB) RefreshTokens(ctx context.Context, username, providedRT string) (*A
return d.getAccountByUsername(ctx, username)
}
+// EnsureValidToken checks if the access token is still valid for a
+// user and refreshes it automatically if expired. Unlike RefreshTokens,
+// it does not require an external refresh token for verification — it
+// uses the stored credentials directly.
+func (d *DB) EnsureValidToken(ctx context.Context, username string) (*Account, error) {
+ var password, storedRT string
+ var accessCreated time.Time
+
+ err := d.conn.QueryRowContext(ctx,
+ `SELECT password, refresh_token, access_token_created FROM account WHERE username = ?`,
+ username,
+ ).Scan(&password, &storedRT, &accessCreated)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, errors.New("account not found")
+ }
+ return nil, fmt.Errorf("query account: %w", err)
+ }
+
+ // Access token older than 24 hours: refresh
+ if time.Since(accessCreated) > 24*time.Hour {
+ tokens, err := WikimediaTokenRefresh(ctx, username, storedRT)
+ if err != nil {
+ // Fallback: full re-auth with stored password
+ log.Printf("token refresh failed for %s, falling back to re-auth: %v", username, err)
+ tokens, err = WikimediaLogin(ctx, username, password)
+ if err != nil {
+ return nil, fmt.Errorf("wikimedia re-auth: %w", err)
+ }
+ if err := d.updateTokens(ctx, username, tokens.RefreshToken, tokens.AccessToken); err != nil {
+ return nil, err
+ }
+ return d.getAccountByUsername(ctx, username)
+ }
+
+ if err := d.updateAccessToken(ctx, username, tokens.AccessToken); err != nil {
+ return nil, err
+ }
+ return d.getAccountByUsername(ctx, username)
+ }
+
+ return d.getAccountByUsername(ctx, username)
+}
+
// updateTokens writes new refresh and access tokens for a user.
func (d *DB) updateTokens(ctx context.Context, username, refreshToken, accessToken string) error {
_, err := d.conn.ExecContext(ctx,