mirror of
https://github.com/atuinsh/atuin.git
synced 2025-07-07 13:15:09 +00:00
feat: Add sqlite server support for self-hosting (#2770)
* Move db_uri setting to DbSettings * WIP: sqlite crate framework * WIP: Migrations * WIP: sqlite implementation * Add sqlite3 to Docker image * verified_at needed for user query * chore(deps): bump debian (#2772) Bumps debian from bookworm-20250428-slim to bookworm-20250520-slim. --- updated-dependencies: - dependency-name: debian dependency-version: bookworm-20250520-slim dependency-type: direct:production ... Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> * fix(doctor): mention the required ble.sh version (#2774) References: https://forum.atuin.sh/t/1047 * fix: Don't print errors in `zsh_autosuggest` helper (#2780) Previously, this would result in long multi-line errors when typing, making it hard to see the shell prompt: ``` $ Error: could not load client settings Caused by: 0: could not create config file 1: failed to create file `/home/jyn/.config/atuin/config.toml` 2: Required key not available (os error 126) Location: atuin-client/src/settings.rs:675:54 fError: could not load client settings Caused by: 0: could not create config file 1: failed to create file `/home/jyn/.config/atuin/config.toml` 2: Required key not available (os error 126) Location: atuin-client/src/settings.rs:675:54 faError: could not load client settings ``` Silence these in autosuggestions, such that they only show up when explicitly invoking atuin. * fix: `atuin.nu` enchancements (#2778) * PR feedback * Remove sqlite3 package * fix(search): prevent panic on malformed format strings (#2776) (#2777) * fix(search): prevent panic on malformed format strings (#2776) - Wrap format operations in panic catcher for graceful error handling - Improve error messages with context-aware guidance for common issues - Let runtime-format parser handle validation to avoid blocking valid formats Fixes crash when using malformed format strings by catching formatting errors gracefully and providing actionable guidance without restricting legitimate format patterns like {command} or {time}. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> * Satisfy cargo fmt * test(search): add regression tests for format string panic (#2776) - Add test for malformed JSON format strings that previously caused panics - Add test to ensure valid format strings continue to work - Prevent future regressions of the format string panic issue 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com> --------- Co-authored-by: Claude <noreply@anthropic.com> --------- Signed-off-by: dependabot[bot] <support@github.com> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Koichi Murase <myoga.murase@gmail.com> Co-authored-by: jyn <github@jyn.dev> Co-authored-by: Tyarel8 <98483313+Tyarel8@users.noreply.github.com> Co-authored-by: Brian Cosgrove <cosgroveb@gmail.com> Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
parent
306f5e1104
commit
7f868711f0
23 changed files with 824 additions and 49 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -10,3 +10,5 @@ publish.sh
|
|||
|
||||
ui/backend/target
|
||||
ui/backend/gen
|
||||
|
||||
sqlite-server.db*
|
||||
|
|
21
Cargo.lock
generated
21
Cargo.lock
generated
|
@ -233,7 +233,9 @@ dependencies = [
|
|||
"atuin-kv",
|
||||
"atuin-scripts",
|
||||
"atuin-server",
|
||||
"atuin-server-database",
|
||||
"atuin-server-postgres",
|
||||
"atuin-server-sqlite",
|
||||
"clap",
|
||||
"clap_complete",
|
||||
"clap_complete_nushell",
|
||||
|
@ -473,6 +475,7 @@ dependencies = [
|
|||
"serde",
|
||||
"time",
|
||||
"tracing",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -489,7 +492,23 @@ dependencies = [
|
|||
"sqlx",
|
||||
"time",
|
||||
"tracing",
|
||||
"url",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atuin-server-sqlite"
|
||||
version = "18.6.1"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"atuin-common",
|
||||
"atuin-server-database",
|
||||
"eyre",
|
||||
"futures-util",
|
||||
"metrics",
|
||||
"serde",
|
||||
"sqlx",
|
||||
"time",
|
||||
"tracing",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
|
|
|
@ -17,3 +17,4 @@ time = { workspace = true }
|
|||
eyre = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
url = "2.5.2"
|
||||
|
|
|
@ -15,7 +15,7 @@ use self::{
|
|||
};
|
||||
use async_trait::async_trait;
|
||||
use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
|
||||
use serde::{Serialize, de::DeserializeOwned};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use time::{Date, Duration, Month, OffsetDateTime, Time, UtcOffset};
|
||||
use tracing::instrument;
|
||||
|
||||
|
@ -41,10 +41,54 @@ impl std::error::Error for DbError {}
|
|||
|
||||
pub type DbResult<T> = Result<T, DbError>;
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum DbType {
|
||||
Postgres,
|
||||
Sqlite,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize)]
|
||||
pub struct DbSettings {
|
||||
pub db_uri: String,
|
||||
}
|
||||
|
||||
impl DbSettings {
|
||||
pub fn db_type(&self) -> DbType {
|
||||
if self.db_uri.starts_with("postgres://") {
|
||||
DbType::Postgres
|
||||
} else if self.db_uri.starts_with("sqlite://") {
|
||||
DbType::Sqlite
|
||||
} else {
|
||||
DbType::Unknown
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Do our best to redact passwords so they're not logged in the event of an error.
|
||||
impl Debug for DbSettings {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
if self.db_type() == DbType::Postgres {
|
||||
let redacted_uri = url::Url::parse(&self.db_uri)
|
||||
.map(|mut url| {
|
||||
let _ = url.set_password(Some("****"));
|
||||
url.to_string()
|
||||
})
|
||||
.unwrap_or(self.db_uri.clone());
|
||||
f.debug_struct("DbSettings")
|
||||
.field("db_uri", &redacted_uri)
|
||||
.finish()
|
||||
} else {
|
||||
f.debug_struct("DbSettings")
|
||||
.field("db_uri", &self.db_uri)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait Database: Sized + Clone + Send + Sync + 'static {
|
||||
type Settings: Debug + Clone + DeserializeOwned + Serialize + Send + Sync + 'static;
|
||||
async fn new(settings: &Self::Settings) -> DbResult<Self>;
|
||||
async fn new(settings: &DbSettings) -> DbResult<Self>;
|
||||
|
||||
async fn get_session(&self, token: &str) -> DbResult<Session>;
|
||||
async fn get_session_user(&self, token: &str) -> DbResult<User>;
|
||||
|
|
|
@ -22,4 +22,3 @@ async-trait = { workspace = true }
|
|||
uuid = { workspace = true }
|
||||
metrics = "0.21.1"
|
||||
futures-util = "0.3"
|
||||
url = "2.5.2"
|
||||
|
|
|
@ -1,15 +1,13 @@
|
|||
use std::collections::HashMap;
|
||||
use std::fmt::Debug;
|
||||
use std::ops::Range;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use atuin_common::record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus};
|
||||
use atuin_common::utils::crypto_random_string;
|
||||
use atuin_server_database::models::{History, NewHistory, NewSession, NewUser, Session, User};
|
||||
use atuin_server_database::{Database, DbError, DbResult};
|
||||
use atuin_server_database::{Database, DbError, DbResult, DbSettings};
|
||||
use futures_util::TryStreamExt;
|
||||
use metrics::counter;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::Row;
|
||||
use sqlx::postgres::PgPoolOptions;
|
||||
|
||||
|
@ -27,26 +25,6 @@ pub struct Postgres {
|
|||
pool: sqlx::Pool<sqlx::postgres::Postgres>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Deserialize, Serialize)]
|
||||
pub struct PostgresSettings {
|
||||
pub db_uri: String,
|
||||
}
|
||||
|
||||
// Do our best to redact passwords so they're not logged in the event of an error.
|
||||
impl Debug for PostgresSettings {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let redacted_uri = url::Url::parse(&self.db_uri)
|
||||
.map(|mut url| {
|
||||
let _ = url.set_password(Some("****"));
|
||||
url.to_string()
|
||||
})
|
||||
.unwrap_or(self.db_uri.clone());
|
||||
f.debug_struct("PostgresSettings")
|
||||
.field("db_uri", &redacted_uri)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
fn fix_error(error: sqlx::Error) -> DbError {
|
||||
match error {
|
||||
sqlx::Error::RowNotFound => DbError::NotFound,
|
||||
|
@ -56,8 +34,7 @@ fn fix_error(error: sqlx::Error) -> DbError {
|
|||
|
||||
#[async_trait]
|
||||
impl Database for Postgres {
|
||||
type Settings = PostgresSettings;
|
||||
async fn new(settings: &PostgresSettings) -> DbResult<Self> {
|
||||
async fn new(settings: &DbSettings) -> DbResult<Self> {
|
||||
let pool = PgPoolOptions::new()
|
||||
.max_connections(100)
|
||||
.connect(settings.db_uri.as_str())
|
||||
|
|
24
crates/atuin-server-sqlite/Cargo.toml
Normal file
24
crates/atuin-server-sqlite/Cargo.toml
Normal file
|
@ -0,0 +1,24 @@
|
|||
[package]
|
||||
name = "atuin-server-sqlite"
|
||||
edition = "2024"
|
||||
description = "server sqlite database library for atuin"
|
||||
|
||||
version = { workspace = true }
|
||||
authors = { workspace = true }
|
||||
license = { workspace = true }
|
||||
homepage = { workspace = true }
|
||||
repository = { workspace = true }
|
||||
|
||||
[dependencies]
|
||||
atuin-common = { path = "../atuin-common", version = "18.6.1" }
|
||||
atuin-server-database = { path = "../atuin-server-database", version = "18.6.1" }
|
||||
|
||||
eyre = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
time = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
metrics = "0.21.1"
|
||||
futures-util = "0.3"
|
5
crates/atuin-server-sqlite/build.rs
Normal file
5
crates/atuin-server-sqlite/build.rs
Normal file
|
@ -0,0 +1,5 @@
|
|||
// generated by `sqlx migrate build-script`
|
||||
fn main() {
|
||||
// trigger recompilation when a new migration is added
|
||||
println!("cargo:rerun-if-changed=migrations");
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
create table store (
|
||||
id text primary key, -- remember to use uuidv7 for happy indices <3
|
||||
client_id text not null, -- I am too uncomfortable with the idea of a client-generated primary key, even though it's fine mathematically
|
||||
host text not null, -- a unique identifier for the host
|
||||
idx bigint not null, -- the index of the record in this store, identified by (host, tag)
|
||||
timestamp bigint not null, -- not a timestamp type, as those do not have nanosecond precision
|
||||
version text not null,
|
||||
tag text not null, -- what is this? history, kv, whatever. Remember clients get a log per tag per host
|
||||
data text not null, -- store the actual history data, encrypted. I don't wanna know!
|
||||
cek text not null,
|
||||
|
||||
user_id bigint not null, -- allow multiple users
|
||||
created_at timestamp not null default current_timestamp
|
||||
);
|
||||
|
||||
create unique index record_uniq ON store(user_id, host, tag, idx);
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
create table history (
|
||||
id integer primary key autoincrement,
|
||||
client_id text not null unique, -- the client-generated ID
|
||||
user_id bigserial not null, -- allow multiple users
|
||||
hostname text not null, -- a unique identifier from the client (can be hashed, random, whatever)
|
||||
timestamp timestamp not null, -- one of the few non-encrypted metadatas
|
||||
|
||||
data text not null, -- store the actual history data, encrypted. I don't wanna know!
|
||||
|
||||
created_at timestamp not null default current_timestamp,
|
||||
deleted_at timestamp
|
||||
);
|
||||
|
||||
create unique index history_deleted_index on history(client_id, user_id, deleted_at);
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
create table sessions (
|
||||
id integer primary key autoincrement,
|
||||
user_id integer,
|
||||
token text unique not null
|
||||
);
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
create table users (
|
||||
id integer primary key autoincrement, -- also store our own ID
|
||||
username text not null unique, -- being able to contact users is useful
|
||||
email text not null unique, -- being able to contact users is useful
|
||||
password text not null unique,
|
||||
created_at timestamp not null default (datetime('now','localtime')),
|
||||
verified_at timestamp with time zone default null
|
||||
);
|
||||
|
||||
-- the prior index is case sensitive :(
|
||||
CREATE UNIQUE INDEX email_unique_idx on users (LOWER(email));
|
||||
CREATE UNIQUE INDEX username_unique_idx on users (LOWER(username));
|
|
@ -0,0 +1,6 @@
|
|||
create table user_verification_token(
|
||||
id integer primary key autoincrement,
|
||||
user_id bigint unique references users(id),
|
||||
token text,
|
||||
valid_until timestamp with time zone
|
||||
);
|
|
@ -0,0 +1,10 @@
|
|||
create table store_idx_cache(
|
||||
id integer primary key autoincrement,
|
||||
user_id bigint,
|
||||
|
||||
host uuid,
|
||||
tag text,
|
||||
idx bigint
|
||||
);
|
||||
|
||||
create unique index store_idx_cache_uniq on store_idx_cache(user_id, host, tag);
|
552
crates/atuin-server-sqlite/src/lib.rs
Normal file
552
crates/atuin-server-sqlite/src/lib.rs
Normal file
|
@ -0,0 +1,552 @@
|
|||
use std::str::FromStr;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use atuin_common::{
|
||||
record::{EncryptedData, HostId, Record, RecordIdx, RecordStatus},
|
||||
utils::crypto_random_string,
|
||||
};
|
||||
use atuin_server_database::{
|
||||
Database, DbError, DbResult, DbSettings,
|
||||
models::{History, NewHistory, NewSession, NewUser, Session, User},
|
||||
};
|
||||
use futures_util::TryStreamExt;
|
||||
use sqlx::{
|
||||
Row,
|
||||
sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions},
|
||||
types::Uuid,
|
||||
};
|
||||
use time::{OffsetDateTime, PrimitiveDateTime, UtcOffset};
|
||||
use tracing::instrument;
|
||||
use wrappers::{DbHistory, DbRecord, DbSession, DbUser};
|
||||
|
||||
mod wrappers;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Sqlite {
|
||||
pool: sqlx::Pool<sqlx::sqlite::Sqlite>,
|
||||
}
|
||||
|
||||
fn fix_error(error: sqlx::Error) -> DbError {
|
||||
match error {
|
||||
sqlx::Error::RowNotFound => DbError::NotFound,
|
||||
error => DbError::Other(error.into()),
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Database for Sqlite {
|
||||
async fn new(settings: &DbSettings) -> DbResult<Self> {
|
||||
let opts = SqliteConnectOptions::from_str(&settings.db_uri)
|
||||
.map_err(fix_error)?
|
||||
.journal_mode(SqliteJournalMode::Wal)
|
||||
.create_if_missing(true);
|
||||
|
||||
let pool = SqlitePoolOptions::new()
|
||||
.connect_with(opts)
|
||||
.await
|
||||
.map_err(fix_error)?;
|
||||
|
||||
sqlx::migrate!("./migrations")
|
||||
.run(&pool)
|
||||
.await
|
||||
.map_err(|error| DbError::Other(error.into()))?;
|
||||
|
||||
Ok(Self { pool })
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn get_session(&self, token: &str) -> DbResult<Session> {
|
||||
sqlx::query_as("select id, user_id, token from sessions where token = $1")
|
||||
.bind(token)
|
||||
.fetch_one(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error)
|
||||
.map(|DbSession(session)| session)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn get_session_user(&self, token: &str) -> DbResult<User> {
|
||||
sqlx::query_as(
|
||||
"select users.id, users.username, users.email, users.password, users.verified_at from users
|
||||
inner join sessions
|
||||
on users.id = sessions.user_id
|
||||
and sessions.token = $1",
|
||||
)
|
||||
.bind(token)
|
||||
.fetch_one(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error)
|
||||
.map(|DbUser(user)| user)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn add_session(&self, session: &NewSession) -> DbResult<()> {
|
||||
let token: &str = &session.token;
|
||||
|
||||
sqlx::query(
|
||||
"insert into sessions
|
||||
(user_id, token)
|
||||
values($1, $2)",
|
||||
)
|
||||
.bind(session.user_id)
|
||||
.bind(token)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn get_user(&self, username: &str) -> DbResult<User> {
|
||||
sqlx::query_as(
|
||||
"select id, username, email, password, verified_at from users where username = $1",
|
||||
)
|
||||
.bind(username)
|
||||
.fetch_one(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error)
|
||||
.map(|DbUser(user)| user)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn get_user_session(&self, u: &User) -> DbResult<Session> {
|
||||
sqlx::query_as("select id, user_id, token from sessions where user_id = $1")
|
||||
.bind(u.id)
|
||||
.fetch_one(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error)
|
||||
.map(|DbSession(session)| session)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn add_user(&self, user: &NewUser) -> DbResult<i64> {
|
||||
let email: &str = &user.email;
|
||||
let username: &str = &user.username;
|
||||
let password: &str = &user.password;
|
||||
|
||||
let res: (i64,) = sqlx::query_as(
|
||||
"insert into users
|
||||
(username, email, password)
|
||||
values($1, $2, $3)
|
||||
returning id",
|
||||
)
|
||||
.bind(username)
|
||||
.bind(email)
|
||||
.bind(password)
|
||||
.fetch_one(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error)?;
|
||||
|
||||
Ok(res.0)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn user_verified(&self, id: i64) -> DbResult<bool> {
|
||||
let res: (bool,) =
|
||||
sqlx::query_as("select verified_at is not null from users where id = $1")
|
||||
.bind(id)
|
||||
.fetch_one(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error)?;
|
||||
|
||||
Ok(res.0)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn verify_user(&self, id: i64) -> DbResult<()> {
|
||||
sqlx::query(
|
||||
"update users set verified_at = (current_timestamp at time zone 'utc') where id=$1",
|
||||
)
|
||||
.bind(id)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn user_verification_token(&self, id: i64) -> DbResult<String> {
|
||||
const TOKEN_VALID_MINUTES: i64 = 15;
|
||||
|
||||
// First we check if there is a verification token
|
||||
let token: Option<(String, sqlx::types::time::OffsetDateTime)> = sqlx::query_as(
|
||||
"select token, valid_until from user_verification_token where user_id = $1",
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error)?;
|
||||
|
||||
let token = if let Some((token, valid_until)) = token {
|
||||
// We have a token, AND it's still valid
|
||||
if valid_until > time::OffsetDateTime::now_utc() {
|
||||
token
|
||||
} else {
|
||||
// token has expired. generate a new one, return it
|
||||
let token = crypto_random_string::<24>();
|
||||
|
||||
sqlx::query("update user_verification_token set token = $2, valid_until = $3 where user_id=$1")
|
||||
.bind(id)
|
||||
.bind(&token)
|
||||
.bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error)?;
|
||||
|
||||
token
|
||||
}
|
||||
} else {
|
||||
// No token in the database! Generate one, insert it
|
||||
let token = crypto_random_string::<24>();
|
||||
|
||||
sqlx::query("insert into user_verification_token (user_id, token, valid_until) values ($1, $2, $3)")
|
||||
.bind(id)
|
||||
.bind(&token)
|
||||
.bind(time::OffsetDateTime::now_utc() + time::Duration::minutes(TOKEN_VALID_MINUTES))
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error)?;
|
||||
|
||||
token
|
||||
};
|
||||
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn update_user_password(&self, user: &User) -> DbResult<()> {
|
||||
sqlx::query(
|
||||
"update users
|
||||
set password = $1
|
||||
where id = $2",
|
||||
)
|
||||
.bind(&user.password)
|
||||
.bind(user.id)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn total_history(&self) -> DbResult<i64> {
|
||||
let res: (i64,) = sqlx::query_as("select count(1) from history")
|
||||
.fetch_optional(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error)?
|
||||
.unwrap_or((0,));
|
||||
|
||||
Ok(res.0)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn count_history(&self, user: &User) -> DbResult<i64> {
|
||||
// The cache is new, and the user might not yet have a cache value.
|
||||
// They will have one as soon as they post up some new history, but handle that
|
||||
// edge case.
|
||||
|
||||
let res: (i64,) = sqlx::query_as(
|
||||
"select count(1) from history
|
||||
where user_id = $1",
|
||||
)
|
||||
.bind(user.id)
|
||||
.fetch_one(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error)?;
|
||||
|
||||
Ok(res.0)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn count_history_cached(&self, _user: &User) -> DbResult<i64> {
|
||||
Err(DbError::NotFound)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn delete_user(&self, u: &User) -> DbResult<()> {
|
||||
sqlx::query("delete from sessions where user_id = $1")
|
||||
.bind(u.id)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error)?;
|
||||
|
||||
sqlx::query("delete from users where id = $1")
|
||||
.bind(u.id)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error)?;
|
||||
|
||||
sqlx::query("delete from history where user_id = $1")
|
||||
.bind(u.id)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn delete_history(&self, user: &User, id: String) -> DbResult<()> {
|
||||
sqlx::query(
|
||||
"update history
|
||||
set deleted_at = $3
|
||||
where user_id = $1
|
||||
and client_id = $2
|
||||
and deleted_at is null", // don't just keep setting it
|
||||
)
|
||||
.bind(user.id)
|
||||
.bind(id)
|
||||
.bind(time::OffsetDateTime::now_utc())
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn deleted_history(&self, user: &User) -> DbResult<Vec<String>> {
|
||||
// The cache is new, and the user might not yet have a cache value.
|
||||
// They will have one as soon as they post up some new history, but handle that
|
||||
// edge case.
|
||||
|
||||
let res = sqlx::query(
|
||||
"select client_id from history
|
||||
where user_id = $1
|
||||
and deleted_at is not null",
|
||||
)
|
||||
.bind(user.id)
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error)?;
|
||||
|
||||
let res = res.iter().map(|row| row.get("client_id")).collect();
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
async fn delete_store(&self, user: &User) -> DbResult<()> {
|
||||
sqlx::query(
|
||||
"delete from store
|
||||
where user_id = $1",
|
||||
)
|
||||
.bind(user.id)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn add_records(&self, user: &User, records: &[Record<EncryptedData>]) -> DbResult<()> {
|
||||
let mut tx = self.pool.begin().await.map_err(fix_error)?;
|
||||
|
||||
for i in records {
|
||||
let id = atuin_common::utils::uuid_v7();
|
||||
|
||||
sqlx::query(
|
||||
"insert into store
|
||||
(id, client_id, host, idx, timestamp, version, tag, data, cek, user_id)
|
||||
values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||
on conflict do nothing
|
||||
",
|
||||
)
|
||||
.bind(id)
|
||||
.bind(i.id)
|
||||
.bind(i.host.id)
|
||||
.bind(i.idx as i64)
|
||||
.bind(i.timestamp as i64) // throwing away some data, but i64 is still big in terms of time
|
||||
.bind(&i.version)
|
||||
.bind(&i.tag)
|
||||
.bind(&i.data.data)
|
||||
.bind(&i.data.content_encryption_key)
|
||||
.bind(user.id)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(fix_error)?;
|
||||
}
|
||||
|
||||
tx.commit().await.map_err(fix_error)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn next_records(
|
||||
&self,
|
||||
user: &User,
|
||||
host: HostId,
|
||||
tag: String,
|
||||
start: Option<RecordIdx>,
|
||||
count: u64,
|
||||
) -> DbResult<Vec<Record<EncryptedData>>> {
|
||||
tracing::debug!("{:?} - {:?} - {:?}", host, tag, start);
|
||||
let start = start.unwrap_or(0);
|
||||
|
||||
let records: Result<Vec<DbRecord>, DbError> = sqlx::query_as(
|
||||
"select client_id, host, idx, timestamp, version, tag, data, cek from store
|
||||
where user_id = $1
|
||||
and tag = $2
|
||||
and host = $3
|
||||
and idx >= $4
|
||||
order by idx asc
|
||||
limit $5",
|
||||
)
|
||||
.bind(user.id)
|
||||
.bind(tag.clone())
|
||||
.bind(host)
|
||||
.bind(start as i64)
|
||||
.bind(count as i64)
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error);
|
||||
|
||||
let ret = match records {
|
||||
Ok(records) => {
|
||||
let records: Vec<Record<EncryptedData>> = records
|
||||
.into_iter()
|
||||
.map(|f| {
|
||||
let record: Record<EncryptedData> = f.into();
|
||||
record
|
||||
})
|
||||
.collect();
|
||||
|
||||
records
|
||||
}
|
||||
Err(DbError::NotFound) => {
|
||||
tracing::debug!("no records found in store: {:?}/{}", host, tag);
|
||||
return Ok(vec![]);
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
};
|
||||
|
||||
Ok(ret)
|
||||
}
|
||||
|
||||
async fn status(&self, user: &User) -> DbResult<RecordStatus> {
|
||||
const STATUS_SQL: &str =
|
||||
"select host, tag, max(idx) from store where user_id = $1 group by host, tag";
|
||||
|
||||
let res: Vec<(Uuid, String, i64)> = sqlx::query_as(STATUS_SQL)
|
||||
.bind(user.id)
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error)?;
|
||||
|
||||
let mut status = RecordStatus::new();
|
||||
|
||||
for i in res {
|
||||
status.set_raw(HostId(i.0), i.1, i.2 as u64);
|
||||
}
|
||||
|
||||
Ok(status)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn count_history_range(
|
||||
&self,
|
||||
user: &User,
|
||||
range: std::ops::Range<time::OffsetDateTime>,
|
||||
) -> DbResult<i64> {
|
||||
let res: (i64,) = sqlx::query_as(
|
||||
"select count(1) from history
|
||||
where user_id = $1
|
||||
and timestamp >= $2::date
|
||||
and timestamp < $3::date",
|
||||
)
|
||||
.bind(user.id)
|
||||
.bind(into_utc(range.start))
|
||||
.bind(into_utc(range.end))
|
||||
.fetch_one(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error)?;
|
||||
|
||||
Ok(res.0)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn list_history(
|
||||
&self,
|
||||
user: &User,
|
||||
created_after: time::OffsetDateTime,
|
||||
since: time::OffsetDateTime,
|
||||
host: &str,
|
||||
page_size: i64,
|
||||
) -> DbResult<Vec<History>> {
|
||||
let res = sqlx::query_as(
|
||||
"select id, client_id, user_id, hostname, timestamp, data, created_at from history
|
||||
where user_id = $1
|
||||
and hostname != $2
|
||||
and created_at >= $3
|
||||
and timestamp >= $4
|
||||
order by timestamp asc
|
||||
limit $5",
|
||||
)
|
||||
.bind(user.id)
|
||||
.bind(host)
|
||||
.bind(into_utc(created_after))
|
||||
.bind(into_utc(since))
|
||||
.bind(page_size)
|
||||
.fetch(&self.pool)
|
||||
.map_ok(|DbHistory(h)| h)
|
||||
.try_collect()
|
||||
.await
|
||||
.map_err(fix_error)?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn add_history(&self, history: &[NewHistory]) -> DbResult<()> {
|
||||
let mut tx = self.pool.begin().await.map_err(fix_error)?;
|
||||
|
||||
for i in history {
|
||||
let client_id: &str = &i.client_id;
|
||||
let hostname: &str = &i.hostname;
|
||||
let data: &str = &i.data;
|
||||
|
||||
sqlx::query(
|
||||
"insert into history
|
||||
(client_id, user_id, hostname, timestamp, data)
|
||||
values ($1, $2, $3, $4, $5)
|
||||
on conflict do nothing
|
||||
",
|
||||
)
|
||||
.bind(client_id)
|
||||
.bind(i.user_id)
|
||||
.bind(hostname)
|
||||
.bind(i.timestamp)
|
||||
.bind(data)
|
||||
.execute(&mut *tx)
|
||||
.await
|
||||
.map_err(fix_error)?;
|
||||
}
|
||||
|
||||
tx.commit().await.map_err(fix_error)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
async fn oldest_history(&self, user: &User) -> DbResult<History> {
|
||||
sqlx::query_as(
|
||||
"select id, client_id, user_id, hostname, timestamp, data, created_at from history
|
||||
where user_id = $1
|
||||
order by timestamp asc
|
||||
limit 1",
|
||||
)
|
||||
.bind(user.id)
|
||||
.fetch_one(&self.pool)
|
||||
.await
|
||||
.map_err(fix_error)
|
||||
.map(|DbHistory(h)| h)
|
||||
}
|
||||
}
|
||||
|
||||
fn into_utc(x: OffsetDateTime) -> PrimitiveDateTime {
|
||||
let x = x.to_offset(UtcOffset::UTC);
|
||||
PrimitiveDateTime::new(x.date(), x.time())
|
||||
}
|
73
crates/atuin-server-sqlite/src/wrappers.rs
Normal file
73
crates/atuin-server-sqlite/src/wrappers.rs
Normal file
|
@ -0,0 +1,73 @@
|
|||
use ::sqlx::{FromRow, Result};
|
||||
use atuin_common::record::{EncryptedData, Host, Record};
|
||||
use atuin_server_database::models::{History, Session, User};
|
||||
use sqlx::{Row, sqlite::SqliteRow};
|
||||
|
||||
pub struct DbUser(pub User);
|
||||
pub struct DbSession(pub Session);
|
||||
pub struct DbHistory(pub History);
|
||||
pub struct DbRecord(pub Record<EncryptedData>);
|
||||
|
||||
impl<'a> FromRow<'a, SqliteRow> for DbUser {
|
||||
fn from_row(row: &'a SqliteRow) -> Result<Self> {
|
||||
Ok(Self(User {
|
||||
id: row.try_get("id")?,
|
||||
username: row.try_get("username")?,
|
||||
email: row.try_get("email")?,
|
||||
password: row.try_get("password")?,
|
||||
verified: row.try_get("verified_at")?,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ::sqlx::FromRow<'a, SqliteRow> for DbSession {
|
||||
fn from_row(row: &'a SqliteRow) -> ::sqlx::Result<Self> {
|
||||
Ok(Self(Session {
|
||||
id: row.try_get("id")?,
|
||||
user_id: row.try_get("user_id")?,
|
||||
token: row.try_get("token")?,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ::sqlx::FromRow<'a, SqliteRow> for DbHistory {
|
||||
fn from_row(row: &'a SqliteRow) -> ::sqlx::Result<Self> {
|
||||
Ok(Self(History {
|
||||
id: row.try_get("id")?,
|
||||
client_id: row.try_get("client_id")?,
|
||||
user_id: row.try_get("user_id")?,
|
||||
hostname: row.try_get("hostname")?,
|
||||
timestamp: row.try_get("timestamp")?,
|
||||
data: row.try_get("data")?,
|
||||
created_at: row.try_get("created_at")?,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ::sqlx::FromRow<'a, SqliteRow> for DbRecord {
|
||||
fn from_row(row: &'a SqliteRow) -> ::sqlx::Result<Self> {
|
||||
let idx: i64 = row.try_get("idx")?;
|
||||
let timestamp: i64 = row.try_get("timestamp")?;
|
||||
|
||||
let data = EncryptedData {
|
||||
data: row.try_get("data")?,
|
||||
content_encryption_key: row.try_get("cek")?,
|
||||
};
|
||||
|
||||
Ok(Self(Record {
|
||||
id: row.try_get("client_id")?,
|
||||
host: Host::new(row.try_get("host")?),
|
||||
idx: idx as u64,
|
||||
timestamp: timestamp as u64,
|
||||
version: row.try_get("version")?,
|
||||
tag: row.try_get("tag")?,
|
||||
data,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DbRecord> for Record<EncryptedData> {
|
||||
fn from(other: DbRecord) -> Record<EncryptedData> {
|
||||
Record { ..other.0 }
|
||||
}
|
||||
}
|
|
@ -9,6 +9,7 @@
|
|||
|
||||
## URI for postgres (using development creds here)
|
||||
# db_uri="postgres://username:password@localhost/atuin"
|
||||
# db_uri="sqlite:///config/atuin-server.db"
|
||||
|
||||
## Maximum size for one history entry
|
||||
# max_history_length = 8192
|
||||
|
|
|
@ -45,10 +45,7 @@ async fn shutdown_signal() {
|
|||
eprintln!("Shutting down gracefully...");
|
||||
}
|
||||
|
||||
pub async fn launch<Db: Database>(
|
||||
settings: Settings<Db::Settings>,
|
||||
addr: SocketAddr,
|
||||
) -> Result<()> {
|
||||
pub async fn launch<Db: Database>(settings: Settings, addr: SocketAddr) -> Result<()> {
|
||||
if settings.tls.enable {
|
||||
launch_with_tls::<Db>(settings, addr, shutdown_signal()).await
|
||||
} else {
|
||||
|
@ -64,7 +61,7 @@ pub async fn launch<Db: Database>(
|
|||
}
|
||||
|
||||
pub async fn launch_with_tcp_listener<Db: Database>(
|
||||
settings: Settings<Db::Settings>,
|
||||
settings: Settings,
|
||||
listener: TcpListener,
|
||||
shutdown: impl Future<Output = ()> + Send + 'static,
|
||||
) -> Result<()> {
|
||||
|
@ -78,7 +75,7 @@ pub async fn launch_with_tcp_listener<Db: Database>(
|
|||
}
|
||||
|
||||
async fn launch_with_tls<Db: Database>(
|
||||
settings: Settings<Db::Settings>,
|
||||
settings: Settings,
|
||||
addr: SocketAddr,
|
||||
shutdown: impl Future<Output = ()>,
|
||||
) -> Result<()> {
|
||||
|
@ -135,9 +132,7 @@ pub async fn launch_metrics_server(host: String, port: u16) -> Result<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn make_router<Db: Database>(
|
||||
settings: Settings<<Db as Database>::Settings>,
|
||||
) -> Result<Router, eyre::Error> {
|
||||
async fn make_router<Db: Database>(settings: Settings) -> Result<Router, eyre::Error> {
|
||||
let db = Db::new(&settings.db_settings)
|
||||
.await
|
||||
.wrap_err_with(|| format!("failed to connect to db: {:?}", settings.db_settings))?;
|
||||
|
|
|
@ -105,10 +105,10 @@ async fn semver(request: Request, next: Next) -> Response {
|
|||
#[derive(Clone)]
|
||||
pub struct AppState<DB: Database> {
|
||||
pub database: DB,
|
||||
pub settings: Settings<DB::Settings>,
|
||||
pub settings: Settings,
|
||||
}
|
||||
|
||||
pub fn router<DB: Database>(database: DB, settings: Settings<DB::Settings>) -> Router {
|
||||
pub fn router<DB: Database>(database: DB, settings: Settings) -> Router {
|
||||
let routes = Router::new()
|
||||
.route("/", get(handlers::index))
|
||||
.route("/healthz", get(handlers::health::health_check))
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
use std::{io::prelude::*, path::PathBuf};
|
||||
|
||||
use atuin_server_database::DbSettings;
|
||||
use config::{Config, Environment, File as ConfigFile, FileFormat};
|
||||
use eyre::{Result, eyre};
|
||||
use fs_err::{File, create_dir_all};
|
||||
use serde::{Deserialize, Serialize, de::DeserializeOwned};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
static EXAMPLE_CONFIG: &str = include_str!("../server.toml");
|
||||
|
||||
|
@ -53,7 +54,7 @@ impl Default for Metrics {
|
|||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct Settings<DbSettings> {
|
||||
pub struct Settings {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub path: String,
|
||||
|
@ -78,7 +79,7 @@ pub struct Settings<DbSettings> {
|
|||
pub db_settings: DbSettings,
|
||||
}
|
||||
|
||||
impl<DbSettings: DeserializeOwned> Settings<DbSettings> {
|
||||
impl Settings {
|
||||
pub fn new() -> Result<Self> {
|
||||
let mut config_file = if let Ok(p) = std::env::var("ATUIN_CONFIG_DIR") {
|
||||
PathBuf::from(p)
|
||||
|
|
|
@ -37,12 +37,19 @@ default = ["client", "sync", "server", "clipboard", "check-update", "daemon"]
|
|||
client = ["atuin-client"]
|
||||
sync = ["atuin-client/sync"]
|
||||
daemon = ["atuin-client/daemon", "atuin-daemon"]
|
||||
server = ["atuin-server", "atuin-server-postgres"]
|
||||
server = [
|
||||
"atuin-server",
|
||||
"atuin-server-database",
|
||||
"atuin-server-postgres",
|
||||
"atuin-server-sqlite",
|
||||
]
|
||||
clipboard = ["arboard"]
|
||||
check-update = ["atuin-client/check-update"]
|
||||
|
||||
[dependencies]
|
||||
atuin-server-database = { path = "../atuin-server-database", version = "18.6.1", optional = true }
|
||||
atuin-server-postgres = { path = "../atuin-server-postgres", version = "18.6.1", optional = true }
|
||||
atuin-server-sqlite = { path = "../atuin-server-sqlite", version = "18.6.1", optional = true }
|
||||
atuin-server = { path = "../atuin-server", version = "18.6.1", optional = true }
|
||||
atuin-client = { path = "../atuin-client", version = "18.6.1", optional = true, default-features = false }
|
||||
atuin-common = { path = "../atuin-common", version = "18.6.1" }
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
use std::net::SocketAddr;
|
||||
|
||||
use atuin_server_database::DbType;
|
||||
use atuin_server_postgres::Postgres;
|
||||
use atuin_server_sqlite::Sqlite;
|
||||
use tracing_subscriber::{EnvFilter, fmt, prelude::*};
|
||||
|
||||
use clap::Parser;
|
||||
use eyre::{Context, Result};
|
||||
use eyre::{Context, Result, eyre};
|
||||
|
||||
use atuin_server::{Settings, example_config, launch, launch_metrics_server};
|
||||
|
||||
|
@ -50,7 +52,13 @@ impl Cmd {
|
|||
));
|
||||
}
|
||||
|
||||
launch::<Postgres>(settings, addr).await
|
||||
match settings.db_settings.db_type() {
|
||||
DbType::Postgres => launch::<Postgres>(settings, addr).await,
|
||||
DbType::Sqlite => launch::<Sqlite>(settings, addr).await,
|
||||
DbType::Unknown => {
|
||||
Err(eyre!("db_uri must start with postgres:// or sqlite://"))
|
||||
}
|
||||
}
|
||||
}
|
||||
Self::DefaultConfig => {
|
||||
println!("{}", example_config());
|
||||
|
|
|
@ -3,7 +3,8 @@ use std::{env, time::Duration};
|
|||
use atuin_client::api_client;
|
||||
use atuin_common::utils::uuid_v7;
|
||||
use atuin_server::{Settings as ServerSettings, launch_with_tcp_listener};
|
||||
use atuin_server_postgres::{Postgres, PostgresSettings};
|
||||
use atuin_server_database::DbSettings;
|
||||
use atuin_server_postgres::Postgres;
|
||||
use futures_util::TryFutureExt;
|
||||
use tokio::{net::TcpListener, sync::oneshot, task::JoinHandle};
|
||||
use tracing::{Dispatch, dispatcher};
|
||||
|
@ -35,7 +36,7 @@ pub async fn start_server(path: &str) -> (String, oneshot::Sender<()>, JoinHandl
|
|||
page_size: 1100,
|
||||
register_webhook_url: None,
|
||||
register_webhook_username: String::new(),
|
||||
db_settings: PostgresSettings { db_uri },
|
||||
db_settings: DbSettings { db_uri },
|
||||
metrics: atuin_server::settings::Metrics::default(),
|
||||
tls: atuin_server::settings::Tls::default(),
|
||||
mail: atuin_server::settings::Mail::default(),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue