use std::collections::HashSet; use async_trait::async_trait; use axum_login::{AuthUser, AuthzBackend}; use axum_login::{AuthnBackend, UserId}; use log::debug; use password_auth::{generate_hash, is_hash_obsolete, verify_password}; use sea_orm::{ ActiveModelTrait, ActiveValue, ColumnTrait, DatabaseConnection, EntityTrait, ModelTrait, QueryFilter, }; use serde::{Deserialize, Serialize}; use thiserror::Error; use tokio::task; use crate::utils::get_current_timestamp; #[derive(Debug, Clone)] pub struct UserRepository { pub(crate) connection: DatabaseConnection, } impl UserRepository { pub fn new(connection: DatabaseConnection) -> Self { UserRepository { connection: connection, } } pub(crate) fn encode_password(password: String) -> String { // This function will try to avoid re-encoding an encoded password. // This is why it is not public outside of the crate. if let Ok(_is_obsolete) = is_hash_obsolete(password.as_str()) { // Not sure what to do if it is obsolete. if _is_obsolete { debug!("UserRepository::encode_password: found obsolete password hash."); } return password; } // As checking for obsoleteness errored out, we assume this a raw password. generate_hash(password) } pub fn new_user(username: &str, password: &str) -> entity::user::ActiveModel { entity::user::ActiveModel { username: sea_orm::ActiveValue::Set(username.to_owned()), password: sea_orm::ActiveValue::Set(UserRepository::encode_password( password.to_owned(), )), date_joined: sea_orm::ActiveValue::Set(Some(get_current_timestamp())), ..::default() } } } #[derive(Clone, Serialize, Deserialize)] pub struct AuthenticatedUser { id: i64, pub username: String, password: String, session_token: Vec, // Stores the hash } impl AuthenticatedUser { fn new(id: i64, username: String, password: String) -> Self { AuthenticatedUser { id: id, username: username, session_token: blake3::hash(&password.as_bytes()).as_bytes().to_vec(), password: password, } } fn set_password(&mut self, new_password: String) { self.password = new_password; self.session_token = blake3::hash(self.password.as_bytes()).as_bytes().to_vec(); } } // Here we've implemented `Debug` manually to avoid accidentally logging the // password hash. impl std::fmt::Debug for AuthenticatedUser { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("User") .field("id", &self.id) .field("username", &self.username) .field("password", &"[redacted]") .finish() } } impl AuthUser for AuthenticatedUser { type Id = i64; fn id(&self) -> Self::Id { self.id } fn session_auth_hash(&self) -> &[u8] { &self.session_token } } impl From for AuthenticatedUser { fn from(value: entity::user::Model) -> Self { AuthenticatedUser::new(value.id, value.username, value.password) } } // This allows us to extract the authentication fields from forms. We use this // to authenticate requests with the backend. #[derive(Debug, Clone, Deserialize)] pub struct Credentials { pub username: String, pub password: String, pub next: Option, } #[derive(Debug, Error)] pub struct NotFoundError { pub details: String, } impl NotFoundError { pub fn new(details: &str) -> Self { NotFoundError { details: details.to_string(), } } } impl std::fmt::Display for NotFoundError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "Not Found... Error: {}", self.details) } } #[derive(Debug, thiserror::Error)] pub enum Error { #[error(transparent)] DbErr(#[from] sea_orm::DbErr), #[error("Not Found: {0}")] NotFound(#[from] NotFoundError), #[error(transparent)] TaskJoin(#[from] task::JoinError), } #[async_trait] impl AuthnBackend for UserRepository { type User = AuthenticatedUser; type Credentials = Credentials; type Error = Error; async fn authenticate( &self, creds: Self::Credentials, ) -> Result, Self::Error> { let user_found = entity::User::find() .filter(entity::user::Column::Username.eq(creds.username)) .one(&self.connection) .await?; if let Some(user) = user_found { let given_password = creds.password.clone(); let user_password = user.password.clone(); let verified = task::spawn_blocking(move || verify_password(&given_password, &user_password)) .await?; if verified.is_ok() { let mut db_user: entity::user::ActiveModel = user.into(); db_user.last_login = ActiveValue::Set(Some(get_current_timestamp())); let user = db_user.update(&self.connection).await?; let rear_user: AuthenticatedUser = user.into(); return Ok(Some(rear_user)); } } Ok(None) // No user found or verification failed } async fn get_user(&self, user_id: &UserId) -> Result, Self::Error> { let user_found = entity::User::find_by_id(*user_id) .one(&self.connection) .await?; if let Some(user) = user_found { let rear_user: AuthenticatedUser = user.into(); Ok(Some(rear_user)) } else { Ok(None) } } } #[derive(Debug, Clone, Eq, PartialEq, Hash)] pub struct Permission { pub name: String, } impl From<&str> for Permission { fn from(name: &str) -> Self { Permission { name: name.to_string(), } } } impl From for Permission { fn from(model: entity::permission::Model) -> Self { Permission { name: model.codename, } } } #[async_trait] impl AuthzBackend for UserRepository { type Permission = Permission; async fn get_group_permissions( &self, user: &Self::User, ) -> Result, Self::Error> { let user = entity::User::find_by_id(user.id) .one(&self.connection) .await?; if let Some(user) = user { let permissions = user .find_related(entity::Permission) .all(&self.connection) .await?; Ok(permissions .into_iter() .map(|item| Permission::from(item)) .collect()) } else { Ok(HashSet::new()) } } } // We use a type alias for convenience. // // Note that we've supplied our concrete backend here. pub type AuthSession = axum_login::AuthSession;