249 lines
7.2 KiB
Rust
249 lines
7.2 KiB
Rust
use std::collections::HashSet;
|
|
|
|
use async_trait::async_trait;
|
|
use axum_login::{AuthUser, AuthzBackend};
|
|
use axum_login::{AuthnBackend, UserId};
|
|
use log::debug;
|
|
use password_auth::{generate_hash, is_hash_obsolete, verify_password};
|
|
use sea_orm::{
|
|
ActiveModelTrait, ActiveValue, ColumnTrait, DatabaseConnection, EntityTrait, ModelTrait,
|
|
QueryFilter,
|
|
};
|
|
use serde::{Deserialize, Serialize};
|
|
use thiserror::Error;
|
|
use tokio::task;
|
|
|
|
use crate::utils::get_current_timestamp;
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct UserRepository {
|
|
pub(crate) connection: DatabaseConnection,
|
|
}
|
|
|
|
impl UserRepository {
|
|
pub fn new(connection: DatabaseConnection) -> Self {
|
|
UserRepository {
|
|
connection: connection,
|
|
}
|
|
}
|
|
|
|
pub(crate) fn encode_password(password: String) -> String {
|
|
// This function will try to avoid re-encoding an encoded password.
|
|
// This is why it is not public outside of the crate.
|
|
if let Ok(_is_obsolete) = is_hash_obsolete(password.as_str()) {
|
|
// Not sure what to do if it is obsolete.
|
|
if _is_obsolete {
|
|
debug!("UserRepository::encode_password: found obsolete password hash.");
|
|
}
|
|
return password;
|
|
}
|
|
// As checking for obsoleteness errored out, we assume this a raw password.
|
|
generate_hash(password)
|
|
}
|
|
|
|
pub fn new_user(username: &str, password: &str) -> entity::user::ActiveModel {
|
|
entity::user::ActiveModel {
|
|
username: sea_orm::ActiveValue::Set(username.to_owned()),
|
|
password: sea_orm::ActiveValue::Set(UserRepository::encode_password(
|
|
password.to_owned(),
|
|
)),
|
|
date_joined: sea_orm::ActiveValue::Set(Some(get_current_timestamp())),
|
|
..<entity::user::ActiveModel as ActiveModelTrait>::default()
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Clone, Serialize, Deserialize)]
|
|
pub struct AuthenticatedUser {
|
|
id: i64,
|
|
pub username: String,
|
|
password: String,
|
|
session_token: Vec<u8>, // Stores the hash
|
|
}
|
|
|
|
impl AuthenticatedUser {
|
|
fn new(id: i64, username: String, password: String) -> Self {
|
|
AuthenticatedUser {
|
|
id: id,
|
|
username: username,
|
|
session_token: blake3::hash(&password.as_bytes()).as_bytes().to_vec(),
|
|
password: password,
|
|
}
|
|
}
|
|
|
|
fn set_password(&mut self, new_password: String) {
|
|
self.password = new_password;
|
|
self.session_token = blake3::hash(self.password.as_bytes()).as_bytes().to_vec();
|
|
}
|
|
}
|
|
|
|
// Here we've implemented `Debug` manually to avoid accidentally logging the
|
|
// password hash.
|
|
impl std::fmt::Debug for AuthenticatedUser {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
f.debug_struct("User")
|
|
.field("id", &self.id)
|
|
.field("username", &self.username)
|
|
.field("password", &"[redacted]")
|
|
.finish()
|
|
}
|
|
}
|
|
|
|
impl AuthUser for AuthenticatedUser {
    type Id = i64;

    /// Primary key of the user row.
    fn id(&self) -> Self::Id {
        let Self { id, .. } = self;
        *id
    }

    /// BLAKE3 digest of the stored password value.
    ///
    /// The digest is computed when the user was built or when its password
    /// was last replaced.
    fn session_auth_hash(&self) -> &[u8] {
        self.session_token.as_slice()
    }
}
|
|
|
|
impl From<entity::user::Model> for AuthenticatedUser {
    /// Converts a user row into an `AuthenticatedUser`.
    ///
    /// Only `id`, `username` and `password` are kept; the session token is
    /// derived from the stored password value.
    fn from(value: entity::user::Model) -> Self {
        let entity::user::Model {
            id,
            username,
            password,
            ..
        } = value;
        Self::new(id, username, password)
    }
}
|
|
|
|
// This allows us to extract the authentication fields from forms. We use this
|
|
// to authenticate requests with the backend.
|
|
#[derive(Debug, Clone, Deserialize)]
|
|
pub struct Credentials {
|
|
pub username: String,
|
|
pub password: String,
|
|
pub next: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Error)]
|
|
pub struct NotFoundError {
|
|
pub details: String,
|
|
}
|
|
|
|
impl NotFoundError {
|
|
pub fn new(details: &str) -> Self {
|
|
NotFoundError {
|
|
details: details.to_string(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl std::fmt::Display for NotFoundError {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
write!(f, "Not Found... Error: {}", self.details)
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, thiserror::Error)]
|
|
pub enum Error {
|
|
#[error(transparent)]
|
|
DbErr(#[from] sea_orm::DbErr),
|
|
|
|
#[error("Not Found: {0}")]
|
|
NotFound(#[from] NotFoundError),
|
|
|
|
#[error(transparent)]
|
|
TaskJoin(#[from] task::JoinError),
|
|
}
|
|
|
|
#[async_trait]
|
|
impl AuthnBackend for UserRepository {
|
|
type User = AuthenticatedUser;
|
|
type Credentials = Credentials;
|
|
type Error = Error;
|
|
|
|
async fn authenticate(
|
|
&self,
|
|
creds: Self::Credentials,
|
|
) -> Result<Option<Self::User>, Self::Error> {
|
|
let user_found = entity::User::find()
|
|
.filter(entity::user::Column::Username.eq(creds.username))
|
|
.one(&self.connection)
|
|
.await?;
|
|
|
|
if let Some(user) = user_found {
|
|
let given_password = creds.password.clone();
|
|
let user_password = user.password.clone();
|
|
let verified =
|
|
task::spawn_blocking(move || verify_password(&given_password, &user_password))
|
|
.await?;
|
|
|
|
if verified.is_ok() {
|
|
let mut db_user: entity::user::ActiveModel = user.into();
|
|
db_user.last_login = ActiveValue::Set(Some(get_current_timestamp()));
|
|
let user = db_user.update(&self.connection).await?;
|
|
let rear_user: AuthenticatedUser = user.into();
|
|
return Ok(Some(rear_user));
|
|
}
|
|
}
|
|
|
|
Ok(None) // No user found or verification failed
|
|
}
|
|
|
|
async fn get_user(&self, user_id: &UserId<Self>) -> Result<Option<Self::User>, Self::Error> {
|
|
let user_found = entity::User::find_by_id(*user_id)
|
|
.one(&self.connection)
|
|
.await?;
|
|
if let Some(user) = user_found {
|
|
let rear_user: AuthenticatedUser = user.into();
|
|
Ok(Some(rear_user))
|
|
} else {
|
|
Ok(None)
|
|
}
|
|
}
|
|
}
|
|
|
|
/// A named permission, identified solely by its `name` string.
///
/// Because `name` is its only field, two permissions are equal (and hash
/// alike) exactly when their names match. That makes them suitable for a
/// `HashSet`.
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub struct Permission {
    /// Permission identifier. When loaded from the database it comes from
    /// the permission row's `codename`.
    pub name: String,
}
|
|
|
|
impl From<&str> for Permission {
|
|
fn from(name: &str) -> Self {
|
|
Permission {
|
|
name: name.to_string(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl From<entity::permission::Model> for Permission {
    /// Uses the row's `codename` as the permission name.
    ///
    /// Any other columns are ignored.
    fn from(model: entity::permission::Model) -> Self {
        let name = model.codename;
        Self { name }
    }
}
|
|
|
|
#[async_trait]
|
|
impl AuthzBackend for UserRepository {
|
|
type Permission = Permission;
|
|
|
|
async fn get_group_permissions(
|
|
&self,
|
|
user: &Self::User,
|
|
) -> Result<HashSet<Self::Permission>, Self::Error> {
|
|
let user = entity::User::find_by_id(user.id)
|
|
.one(&self.connection)
|
|
.await?;
|
|
if let Some(user) = user {
|
|
let permissions = user
|
|
.find_related(entity::Permission)
|
|
.all(&self.connection)
|
|
.await?;
|
|
|
|
Ok(permissions
|
|
.into_iter()
|
|
.map(|item| Permission::from(item))
|
|
.collect())
|
|
} else {
|
|
Ok(HashSet::new())
|
|
}
|
|
}
|
|
}
|
|
|
|
// We use a type alias for convenience.
//
// Note that we've supplied our concrete backend here.
/// axum-login session type bound to the [`UserRepository`] backend.
pub type AuthSession = axum_login::AuthSession<UserRepository>;
|