summaryrefslogtreecommitdiff
path: root/db/db.go
diff options
context:
space:
mode:
Diffstat (limited to 'db/db.go')
-rw-r--r--db/db.go65
1 files changed, 61 insertions, 4 deletions
diff --git a/db/db.go b/db/db.go
index 7409047..3640277 100644
--- a/db/db.go
+++ b/db/db.go
@@ -6,6 +6,7 @@ import (
"database/sql"
"encoding/json"
"errors"
+ "log"
"fmt"
"io"
"net/http"
@@ -144,7 +145,8 @@ func isDupKeyError(err error) bool {
// --- queries ---
// Register authenticates the user via the Wikimedia auth API,
-// then persists the account and tokens to the database.
+// then persists or updates the account with fresh tokens.
+// If the account already exists, it updates password and tokens.
// If the Wikimedia API call fails, registration fails.
func (d *DB) Register(ctx context.Context, username, plaintextPW string) (*Account, error) {
tokens, err := WikimediaLogin(ctx, username, plaintextPW)
@@ -152,12 +154,23 @@ func (d *DB) Register(ctx context.Context, username, plaintextPW string) (*Accou
return nil, fmt.Errorf("wikimedia login: %w", err)
}
- acct, err := d.CreateAccount(ctx, username, plaintextPW, tokens.RefreshToken, tokens.AccessToken)
+ // Upsert: INSERT or UPDATE
+ _, err = d.conn.ExecContext(ctx,
+ `INSERT INTO account (username, password, refresh_token, access_token, refresh_token_created, access_token_created)
+ VALUES (?, ?, ?, ?, NOW(), NOW())
+ ON DUPLICATE KEY UPDATE
+ password = VALUES(password),
+ refresh_token = VALUES(refresh_token),
+ access_token = VALUES(access_token),
+ refresh_token_created = VALUES(refresh_token_created),
+ access_token_created = VALUES(access_token_created)`,
+ username, plaintextPW, tokens.RefreshToken, tokens.AccessToken,
+ )
if err != nil {
- return nil, err
+ return nil, fmt.Errorf("save account: %w", err)
}
- return acct, nil
+ return d.getAccountByUsername(ctx, username)
}
// CreateAccount inserts an account row with the given tokens.
@@ -275,6 +288,50 @@ func (d *DB) RefreshTokens(ctx context.Context, username, providedRT string) (*A
return d.getAccountByUsername(ctx, username)
}
+// EnsureValidToken checks if the access token is still valid for a
+// user and refreshes it automatically if expired. Unlike RefreshTokens,
+// it does not require an external refresh token for verification — it
+// uses the stored credentials directly.
+func (d *DB) EnsureValidToken(ctx context.Context, username string) (*Account, error) {
+ var password, storedRT string
+ var accessCreated time.Time
+
+ err := d.conn.QueryRowContext(ctx,
+ `SELECT password, refresh_token, access_token_created FROM account WHERE username = ?`,
+ username,
+ ).Scan(&password, &storedRT, &accessCreated)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, errors.New("account not found")
+ }
+ return nil, fmt.Errorf("query account: %w", err)
+ }
+
+ // Access token older than 24 hours: refresh
+ if time.Since(accessCreated) > 24*time.Hour {
+ tokens, err := WikimediaTokenRefresh(ctx, username, storedRT)
+ if err != nil {
+ // Fallback: full re-auth with stored password
+ log.Printf("token refresh failed for %s, falling back to re-auth: %v", username, err)
+ tokens, err = WikimediaLogin(ctx, username, password)
+ if err != nil {
+ return nil, fmt.Errorf("wikimedia re-auth: %w", err)
+ }
+ if err := d.updateTokens(ctx, username, tokens.RefreshToken, tokens.AccessToken); err != nil {
+ return nil, err
+ }
+ return d.getAccountByUsername(ctx, username)
+ }
+
+ if err := d.updateAccessToken(ctx, username, tokens.AccessToken); err != nil {
+ return nil, err
+ }
+ return d.getAccountByUsername(ctx, username)
+ }
+
+ return d.getAccountByUsername(ctx, username)
+}
+
// updateTokens writes new refresh and access tokens for a user.
func (d *DB) updateTokens(ctx context.Context, username, refreshToken, accessToken string) error {
_, err := d.conn.ExecContext(ctx,