diff --git a/src/main.rs b/src/main.rs index 951c72d..2c337da 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,7 +9,7 @@ use axum::{ }; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; -use sqlx::{postgres::PgPoolOptions, PgPool}; +use sqlx::{postgres::PgPoolOptions, PgExecutor, PgPool}; use std::env; use std::time::{Duration, SystemTime, UNIX_EPOCH}; @@ -92,10 +92,9 @@ async fn create_user( .await .unwrap(); - let does_exist: bool = if row.0 == 1 { true } else { false }; + let does_exist: bool = if row.0 >= 1 { true } else { false }; if does_exist == false { - //let _: () = con.set(&user_key, password).unwrap(); sqlx::query("INSERT INTO users (username, password) VALUES ($1, $2)") .bind(&username) .bind(&password) @@ -112,48 +111,51 @@ async fn create_user( (StatusCode::CREATED, Json(json!({ "username" : username }))) } -fn authorize(username: &str, password: &str) -> bool { - /*let client = redis::Client::open("redis://127.0.0.1/").unwrap(); - let mut con = client.get_connection().unwrap(); - +async fn authorize(db: impl PgExecutor<'_>, username: &str, password: &str) -> bool { if username.is_empty() || password.is_empty() { return false; } - let user_key = format!("user:{username}:key"); + let row: Option<(String,)> = sqlx::query_as("SELECT password FROM users WHERE username = $1") + .bind(&username) + .fetch_optional(db) + .await + .unwrap(); - let redis_pw: String = con.get(&user_key).unwrap(); + if let Some(val) = row { + return password == val.0; + } - if password != redis_pw { - return false; - } */ - - true + false } -async fn auth_user(headers: HeaderMap) -> (StatusCode, Json) { +async fn auth_user(State(db_pool): State, headers: HeaderMap) -> (StatusCode, Json) { let username = headers["x-auth-user"].to_str().unwrap_or(""); let password = headers["x-auth-key"].to_str().unwrap_or(""); - if authorize(&username, &password) == false { - return ( - StatusCode::UNAUTHORIZED, - Json(json!({"message" : "Unauthorized"})), - ); + let mut tx = db_pool.begin().await.unwrap(); + + if authorize(&mut *tx, username, password).await == true { + return (StatusCode::OK, Json(json!({"authorized" : "OK"}))); } - (StatusCode::OK, Json(json!({"authorized" : "OK"}))) + tx.commit().await.unwrap(); + + ( + StatusCode::UNAUTHORIZED, + Json(json!({"message" : "Unauthorized"})), + ) } async fn update_progress(headers: HeaderMap, Json(payload): Json) -> StatusCode { let username = headers["x-auth-user"].to_str().unwrap_or(""); let password = headers["x-auth-key"].to_str().unwrap_or(""); - if authorize(username, password) == false { + /*if authorize(username, password) == false { return StatusCode::UNAUTHORIZED; } - /*let client = redis::Client::open("redis://127.0.0.1/").unwrap(); + let client = redis::Client::open("redis://127.0.0.1/").unwrap(); let mut con = client.get_connection().unwrap(); let timestamp = SystemTime::now() @@ -186,11 +188,11 @@ async fn get_progress( let username = headers["x-auth-user"].to_str().unwrap_or(""); let password = headers["x-auth-key"].to_str().unwrap_or(""); - if authorize(username, password) == false { + /* if authorize(username, password) == false { return (StatusCode::UNAUTHORIZED, Json(json!(""))); } - /* let client = redis::Client::open("redis://127.0.0.1/").unwrap(); + let client = redis::Client::open("redis://127.0.0.1/").unwrap(); let mut con = client.get_connection().unwrap(); let doc_key = format!("user:{username}:document:{document}");