Skip to content

Commit

Permalink
refactor secret again
Browse files Browse the repository at this point in the history
  • Loading branch information
robatipoor committed Jan 19, 2024
1 parent 3bc3800 commit 131ffd0
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 53 deletions.
22 changes: 11 additions & 11 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,18 @@ opt-level = "z"
strip = true

[workspace.dependencies]
anyhow = "1.0.78"
anyhow = "1.0.79"
argon2 = "0.5.2"
assert_cmd = "2.0.12"
assert_cmd = "2.0.13"
async-stream = {version = "0.3.5"}
async-trait = "0.1.76"
axum = {version = "0.7.3", features = ["multipart"]}
axum-extra = {version = "0.9.1", features = ["async-read-body"]}
base64 = "0.21.5"
async-trait = "0.1.77"
axum = {version = "0.7.4", features = ["multipart"]}
axum-extra = {version = "0.9.2", features = ["async-read-body"]}
base64 = "0.21.7"
bincode = "1.3.3"
build_html = "2.4.0"
chrono = {version = "0.4.31", features = ["serde"]}
clap = {version = "4.4.11", features = ["derive"]}
clap = {version = "4.4.18", features = ["derive"]}
config = {version = "0.13.4", default-features = false, features = ["toml"]}
cuid2 = "0.1.2"
fake = {version = "2.9.2", features = ['derive', 'uuid', 'chrono']}
Expand All @@ -53,16 +53,16 @@ qrcodegen = "1.8.0"
rand = "0.8.5"
regex = "1.10.2"
reqwest = {version = "0.11.23", default-features = false, features = ["json", "multipart", "stream", "rustls-tls"]}
serde = {version = "1.0.193", features = ["derive"]}
serde_json = "1.0.108"
serde = {version = "1.0.195", features = ["derive"]}
serde_json = "1.0.111"
sled = "0.34.7"
strum = { version = "0.25.0", features = ["derive"] }
test-context = "0.1.4"
thiserror = "1.0.53"
thiserror = "1.0.56"
tokio = {version = "1.35.1", features = ["macros", "time", "process", "net", "rt-multi-thread"]}
tokio-util = "0.7.10"
tower = {version = "0.4.13", features = ["util"]}
tower-http = {version = "0.5.0", features = ["fs"]}
tower-http = {version = "0.5.1", features = ["fs"]}
tracing = "0.1.40"
tracing-subscriber = "0.3.18"
url = "2.5.0"
Expand Down
21 changes: 11 additions & 10 deletions api/src/database/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::sync::Arc;
use std::time::Duration;
use std::{collections::BTreeSet, path::PathBuf};

use crate::util::secret::SecretHash;
use crate::{
configure::DatabaseConfig,
error::{ApiError, ApiResult},
Expand Down Expand Up @@ -204,7 +205,7 @@ impl FilePath {
pub struct MetaDataFile {
pub created_at: DateTime<Utc>,
pub expiration_date: DateTime<Utc>,
pub auth: Option<String>,
pub secret: Option<SecretHash>,
pub delete_manually: bool,
pub max_download: Option<u32>,
pub count_downloads: u32,
Expand Down Expand Up @@ -309,7 +310,7 @@ mod tests {
let meta = MetaDataFile {
created_at: Utc::now(),
expiration_date: Utc::now() + chrono::Duration::seconds(10),
auth: None,
secret: None,
delete_manually: true,
max_download: None,
count_downloads: 1,
Expand All @@ -323,7 +324,7 @@ mod tests {
let result = ctx.state.db.fetch(&path).unwrap().unwrap();
assert_eq!(result.created_at, meta.created_at);
assert_eq!(result.expiration_date, meta.expiration_date);
assert_eq!(result.auth, meta.auth);
assert_eq!(result.secret, meta.secret);
assert_eq!(result.max_download, meta.max_download);
assert_eq!(result.count_downloads, meta.count_downloads);
}
Expand All @@ -335,7 +336,7 @@ mod tests {
let meta = MetaDataFile {
created_at: Utc::now(),
expiration_date: Utc::now() + chrono::Duration::seconds(10),
auth: None,
secret: None,
delete_manually: true,
max_download: None,
count_downloads: 0,
Expand All @@ -349,7 +350,7 @@ mod tests {
let result = ctx.state.db.fetch_count(&path).await.unwrap().unwrap();
assert_eq!(result.created_at, meta.created_at);
assert_eq!(result.expiration_date, meta.expiration_date);
assert_eq!(result.auth, meta.auth);
assert_eq!(result.secret, meta.secret);
assert_eq!(result.max_download, meta.max_download);
assert_eq!(result.count_downloads, meta.count_downloads);
}
Expand All @@ -361,7 +362,7 @@ mod tests {
let meta = MetaDataFile {
created_at: Utc::now(),
expiration_date: Utc::now() + chrono::Duration::seconds(10),
auth: None,
secret: None,
delete_manually: true,
max_download: None,
count_downloads: 0,
Expand All @@ -376,7 +377,7 @@ mod tests {
let result = ctx.state.db.fetch_count(&path).await.unwrap().unwrap();
assert_eq!(result.created_at, meta.created_at);
assert_eq!(result.expiration_date, meta.expiration_date);
assert_eq!(result.auth, meta.auth);
assert_eq!(result.secret, meta.secret);
assert_eq!(result.max_download, meta.max_download);
assert_eq!(result.count_downloads, meta.count_downloads + 1);
}
Expand All @@ -388,7 +389,7 @@ mod tests {
let meta = MetaDataFile {
created_at: Utc::now(),
expiration_date: Utc::now() + chrono::Duration::seconds(10),
auth: None,
secret: None,
delete_manually: true,
max_download: None,
count_downloads: 0,
Expand All @@ -410,7 +411,7 @@ mod tests {
let meta = MetaDataFile {
created_at: Utc::now(),
expiration_date: Utc::now(),
auth: None,
secret: None,
delete_manually: true,
max_download: None,
count_downloads: 0,
Expand All @@ -433,7 +434,7 @@ mod tests {
let meta = MetaDataFile {
created_at: Utc::now(),
expiration_date: Utc::now() + chrono::Duration::seconds(10),
auth: None,
secret: None,
delete_manually: true,
max_download: None,
count_downloads: 0,
Expand Down
37 changes: 15 additions & 22 deletions api/src/service/file.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::error::{ApiError, ApiResult, ToApiResult};
use crate::util::secret::Secret;
use crate::util::secret::{Secret, SecretHash};
use anyhow::anyhow;
use axum::extract::multipart::Field;
use axum::extract::Multipart;
Expand All @@ -22,7 +22,7 @@ pub async fn store(
secret: Option<Secret>,
mut multipart: Multipart,
) -> ApiResult<(FilePath, DateTime<Utc>)> {
let auth = secret.map(|s| s.hash()).transpose()?;
let secret = secret.map(|s| s.hash()).transpose()?;
let expire_secs = query
.expire_time
.unwrap_or(state.config.default_expire_secs) as i64;
Expand All @@ -36,7 +36,7 @@ pub async fn store(
expiration_date,
delete_manually: query.delete_manually.unwrap_or(true),
max_download: query.max_download,
auth,
secret,
count_downloads: 0,
};
while let Ok(Some(field)) = multipart.next_field().await {
Expand Down Expand Up @@ -72,7 +72,9 @@ pub async fn store(
state.db.flush().await?;
return Ok((path, expiration_date));
}
Err(ApiError::BadRequest("multpart is empty".to_string()))
Err(ApiError::BadRequest(
"multipart/form-data empty body".to_string(),
))
}

pub async fn store_stream(file_path: &PathBuf, field: Field<'_>) -> ApiResult<()> {
Expand All @@ -91,7 +93,7 @@ pub async fn info(
state: &ApiState,
code: &str,
file_name: &str,
auth: Option<Secret>,
secret: Option<Secret>,
) -> ApiResult<MetaDataFile> {
let path = FilePath {
code: code.to_string(),
Expand All @@ -104,7 +106,7 @@ pub async fn info(
return Err(ApiError::NotFound(format!("{} not found", path.url_path())));
}
}
authenticate(auth, &meta.auth)?;
authorize_user(secret, &meta.secret)?;
Ok(meta)
}

Expand All @@ -125,7 +127,7 @@ pub async fn fetch(
return Err(ApiError::NotFound(format!("{} not found", path.url_path())));
}
}
authenticate(secret, &meta.auth)?;
authorize_user(secret, &meta.secret)?;
read_file(&state.config.fs.base_dir.join(&path.url_path())).await
}

Expand All @@ -141,7 +143,7 @@ pub async fn delete(
};
if let Some(meta) = state.db.fetch(&path)? {
if meta.delete_manually {
authenticate(secret, &meta.auth)?;
authorize_user(secret, &meta.secret)?;
let file_path = path.fs_path(&state.config.fs.base_dir);
tokio::fs::remove_file(file_path).await?;
state.db.delete(path).await?;
Expand All @@ -160,34 +162,25 @@ pub async fn read_file(file_path: &PathBuf) -> ApiResult<ServeFile> {
Ok(ServeFile::new(file_path))
}

pub fn authenticate(secret: Option<Secret>, hash: &Option<String>) -> ApiResult<()> {
if let Some(hash) = hash {
match secret.map(|s| s.check(hash)) {
pub fn authorize_user(secret: Option<Secret>, secret_hash: &Option<SecretHash>) -> ApiResult<()> {
if let Some(hash) = secret_hash {
match secret.map(|s| s.verify(hash)) {
Some(Ok(_)) => return Ok(()),
Some(Err(e)) if e == argon2::password_hash::Error::Password => Err(
ApiError::PermissionDenied("Secret token is invalid".to_string()),
),
Some(Err(e)) => Err(ApiError::Unknown(anyhow!(
"Unexpected error happened: {}",
e
"An Unexpected error occurred: {e}",
))),
None => Err(ApiError::PermissionDenied(
"Authorization header should be set".to_string(),
"Authorization header required.".to_string(),
)),
}
} else {
Ok(())
}
}

pub fn hash(auth: Option<String>) -> ApiResult<Option<String>> {
auth
.as_ref()
.map(crate::util::hash::argon_hash)
.transpose()
.map_err(|e| ApiError::HashError(e.to_string()))
}

pub fn calc_expiration_date(now: DateTime<Utc>, secs: i64) -> DateTime<Utc> {
now + chrono::Duration::seconds(secs)
}
2 changes: 1 addition & 1 deletion api/src/util/http.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use hyper::HeaderMap;

use crate::error::{invalid_input_error, ApiError, ApiResult};
use crate::error::{invalid_input_error, ApiResult};

use super::secret::Secret;

Expand Down
21 changes: 12 additions & 9 deletions api/src/util/secret.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
use crate::error::{ApiError, ApiResult};

#[derive(Debug)]
pub struct Secret {
inner: String,
}
#[derive(Debug, Clone, Eq, PartialEq, PartialOrd, Ord)]
pub struct Secret(String);

#[derive(Debug, serde::Serialize, serde::Deserialize, Clone, Eq, PartialEq, PartialOrd, Ord)]
pub struct SecretHash(String);

impl Secret {
pub fn new(secret: String) -> Self {
Self { inner: secret }
Self(secret)
}

pub fn check(&self, hash: &str) -> Result<(), argon2::password_hash::Error> {
crate::util::hash::argon_verify(&self.inner, hash)
pub fn verify(&self, hash: &SecretHash) -> Result<(), argon2::password_hash::Error> {
crate::util::hash::argon_verify(&self.0, &hash.0)
}

pub fn hash(&self) -> ApiResult<String> {
crate::util::hash::argon_hash(&self.inner).map_err(|e| ApiError::HashError(e.to_string()))
pub fn hash(&self) -> ApiResult<SecretHash> {
crate::util::hash::argon_hash(&self.0)
.map_err(|e| ApiError::HashError(e.to_string()))
.map(|hash| SecretHash(hash))
}
}

0 comments on commit 131ffd0

Please sign in to comment.