diff --git a/Cargo.lock b/Cargo.lock index b241c0e..87372cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -534,6 +534,7 @@ dependencies = [ name = "nfsense" version = "0.1.0" dependencies = [ + "async-trait", "axum", "custom_error", "ipnet", @@ -541,6 +542,7 @@ dependencies = [ "pwhash", "serde", "serde_json", + "thiserror", "tokio", "tower-cookies", "tower-http", @@ -985,6 +987,26 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" +[[package]] +name = "thiserror" +version = "1.0.50" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.50" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.38", +] + [[package]] name = "thread_local" version = "1.1.7" diff --git a/Cargo.toml b/Cargo.toml index d27ab4d..4a65bef 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +async-trait = "0.1.74" axum = "0.6.20" custom_error = "1.9.2" ipnet = { version = "2.8.0", features = ["serde"] } @@ -13,6 +14,7 @@ macaddr = { version = "1.0.1", features = ["serde"] } pwhash = "1.0.0" serde = { version = "1.0.189", features = ["derive"] } serde_json = "1.0.107" +thiserror = "1.0.50" tokio = { version = "1.33.0", features = ["full"] } tower-cookies = "0.9.0" tower-http = "0.4.4" diff --git a/src/json_rpc/error.rs b/src/json_rpc/error.rs new file mode 100644 index 0000000..21fb2d8 --- /dev/null +++ b/src/json_rpc/error.rs @@ -0,0 +1,97 @@ +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +/// Constants for [error object](https://www.jsonrpc.org/specification#error_object) +pub const INVALID_REQUEST: i32 = -32600; +pub const METHOD_NOT_FOUND: i32 = -32601; +pub const INVALID_PARAMS: i32 = -32602; +pub const INTERNAL_ERROR: i32 = -32603; +pub const PARSE_ERROR: i32 = -32700; + +#[derive(Debug)] +pub enum JsonRpcErrorReason { + ParseError, + InvalidRequest, + MethodNotFound, + InvalidParams, + InternalError, + /// -32000 to -32099 + ServerError(i32), +} + +impl std::fmt::Display for JsonRpcErrorReason { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + JsonRpcErrorReason::ParseError => write!(f, "Parse error"), + JsonRpcErrorReason::InvalidRequest => write!(f, "Invalid Request"), + JsonRpcErrorReason::MethodNotFound => write!(f, "Method not found"), + JsonRpcErrorReason::InvalidParams => write!(f, "Invalid params"), + JsonRpcErrorReason::InternalError => write!(f, "Internal error"), + JsonRpcErrorReason::ServerError(code) => write!(f, "Server error: {}", code), + } + } +} + +impl From for i32 { + fn from(reason: JsonRpcErrorReason) -> i32 { + match reason { + JsonRpcErrorReason::ParseError => PARSE_ERROR, + JsonRpcErrorReason::InvalidRequest => INVALID_REQUEST, + JsonRpcErrorReason::MethodNotFound => METHOD_NOT_FOUND, + JsonRpcErrorReason::InvalidParams => INVALID_PARAMS, + JsonRpcErrorReason::InternalError => INTERNAL_ERROR, + JsonRpcErrorReason::ServerError(code) => code, + } + } +} + +impl JsonRpcErrorReason { + fn new(code: i32) -> Self { + match code { + PARSE_ERROR => Self::ParseError, + INVALID_REQUEST => Self::InvalidRequest, + METHOD_NOT_FOUND => Self::MethodNotFound, + INVALID_PARAMS => Self::InvalidParams, + INTERNAL_ERROR => Self::InternalError, + other => Self::ServerError(other), + } + } +} + +#[derive(Debug, Error, Serialize, Deserialize)] +pub struct JsonRpcError { + code: i32, + message: String, + data: serde_json::Value, +} + +impl JsonRpcError { + pub fn new(code: JsonRpcErrorReason, message: String, data: serde_json::Value) -> Self { + Self { + code: code.into(), + message, + data, + } + } +} + +impl std::fmt::Display for JsonRpcError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}: {}", + JsonRpcErrorReason::new(self.code), + self.message + ) + } +} + +impl JsonRpcError { + pub fn error_reason(&self) -> JsonRpcErrorReason { + JsonRpcErrorReason::new(self.code) + } + + pub fn code(&self) -> i32 { + self.code + } +} diff --git a/src/json_rpc/mod.rs b/src/json_rpc/mod.rs new file mode 100644 index 0000000..c5e1cb1 --- /dev/null +++ b/src/json_rpc/mod.rs @@ -0,0 +1,191 @@ +// Copied from https://github.com/ralexstokes/axum-json-rpc since it needed minor modifications + +use axum::body::HttpBody; +use axum::extract::{FromRequest, FromRequestParts}; +use axum::http::Request; +use axum::response::{IntoResponse, Response}; +use axum::{BoxError, Json}; +use error::{JsonRpcError, JsonRpcErrorReason}; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +pub mod error; + +/// Hack until [try_trait_v2](https://github.com/rust-lang/rust/issues/84277) is not stabilized +pub type JsonRpcResult = Result; + +#[derive(Deserialize, Debug)] +#[serde(deny_unknown_fields)] +struct JsonRpcRequest { + id: i64, + jsonrpc: String, + method: String, + params: Option, +} + +/// Parses a JSON-RPC request, and returns the request ID, the method name, and the parameters. +/// If the request is invalid, returns an error. +/// ```rust +/// use axum_jrpc::{JsonRpcResult, JsonRpcExtractor, JsonRpcResponse}; +/// +/// fn router(req: JsonRpcExtractor) -> JsonRpcResult { +/// let req_id = req.get_request_id()?; +/// let method = req.method(); +/// match method { +/// "add" => { +/// let params: [i32;2] = req.parse_params()?; +/// return Ok(JsonRpcResponse::success(req_id, params[0] + params[1])) +/// } +/// m => Ok(req.method_not_found(m)) +/// } +/// } +/// ``` +#[derive(Debug)] +pub struct JsonRpcExtractor { + pub parsed: Value, + pub method: String, + pub id: i64, +} + +impl JsonRpcExtractor { + pub fn get_request_id(&self) -> i64 { + self.id + } + + pub fn parse_params(self) -> Result { + let value = serde_json::from_value(self.parsed); + match value { + Ok(v) => Ok(v), + Err(e) => { + let error = JsonRpcError::new( + JsonRpcErrorReason::InvalidParams, + e.to_string(), + Value::Null, + ); + Err(JsonRpcResponse::error(self.id, error)) + } + } + } + + pub fn method(&self) -> &str { + &self.method + } + + pub fn method_not_found(&self, method: &str) -> JsonRpcResponse { + let error = JsonRpcError::new( + JsonRpcErrorReason::MethodNotFound, + format!("Method `{}` not found", method), + Value::Null, + ); + JsonRpcResponse::error(self.id, error) + } +} + +#[async_trait::async_trait] +impl FromRequest for JsonRpcExtractor +where + S: Send + Sync, + B: Send + 'static, + B: HttpBody + Send, + B::Data: Send, + B::Error: Into, +{ + type Rejection = JsonRpcResponse; + + async fn from_request(req: Request, state: &S) -> Result { + let json = Json::from_request(req, state).await; + let parsed: JsonRpcRequest = match json { + Ok(a) => a.0, + Err(e) => { + return Err(JsonRpcResponse { + id: 0, + jsonrpc: "2.0".to_owned(), + result: JsonRpcAnswer::Error(JsonRpcError::new( + JsonRpcErrorReason::InvalidRequest, + e.to_string(), + Value::Null, + )), + }) + } + }; + if parsed.jsonrpc != "2.0" { + return Err(JsonRpcResponse { + id: parsed.id, + jsonrpc: "2.0".to_owned(), + result: JsonRpcAnswer::Error(JsonRpcError::new( + JsonRpcErrorReason::InvalidRequest, + "Invalid jsonrpc version".to_owned(), + Value::Null, + )), + }); + } + Ok(Self { + parsed: match parsed.params { + Some(p) => p, + None => Value::Null, + }, + method: parsed.method, + id: parsed.id, + }) + } +} + +#[derive(Serialize, Debug, Deserialize)] +/// A JSON-RPC response. +pub struct JsonRpcResponse { + jsonrpc: String, + pub result: JsonRpcAnswer, + /// The request ID. + id: i64, +} + +impl JsonRpcResponse { + /// Returns a response with the given result + /// Returns JsonRpcError if the `result` is invalid input for [`serde_json::to_value`] + pub fn success(id: i64, result: T) -> Self { + let result = match serde_json::to_value(result) { + Ok(v) => v, + Err(e) => { + let err = JsonRpcError::new( + JsonRpcErrorReason::InternalError, + e.to_string(), + Value::Null, + ); + return JsonRpcResponse { + id, + jsonrpc: "2.0".to_owned(), + result: JsonRpcAnswer::Error(err), + }; + } + }; + + JsonRpcResponse { + id, + jsonrpc: "2.0".to_owned(), + result: JsonRpcAnswer::Result(result), + } + } + + pub fn error(id: i64, error: JsonRpcError) -> Self { + JsonRpcResponse { + id, + jsonrpc: "2.0".to_owned(), + result: JsonRpcAnswer::Error(error), + } + } +} + +impl IntoResponse for JsonRpcResponse { + fn into_response(self) -> Response { + Json(self).into_response() + } +} + +#[derive(Serialize, Debug, Deserialize)] +#[serde(untagged)] +/// JsonRpc [response object](https://www.jsonrpc.org/specification#response_object) +pub enum JsonRpcAnswer { + Result(Value), + Error(JsonRpcError), +} diff --git a/src/main.rs b/src/main.rs index cf69770..1c543fd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -17,6 +17,7 @@ use web::auth::SessionState; mod config_manager; mod definitions; +mod json_rpc; mod state; mod web; diff --git a/src/web/rpc.rs b/src/web/rpc.rs index daae823..cab913a 100644 --- a/src/web/rpc.rs +++ b/src/web/rpc.rs @@ -1,3 +1,4 @@ +use super::super::json_rpc::{JsonRpcExtractor, JsonRpcResponse, JsonRpcResult}; use super::super::AppState; use axum::routing::post; use axum::{Json, Router}; @@ -20,7 +21,22 @@ pub fn routes() -> Router { async fn api_handler( State(state): State, session: Extension, -) -> impl IntoResponse { + req: JsonRpcExtractor, +) -> JsonRpcResult { info!("api hit! user: {:?}", session.username); - StatusCode::NOT_IMPLEMENTED + let req_id = req.get_request_id(); + let method = req.method(); + let response = match method { + "add" => { + let params: [i32; 2] = req.parse_params()?; + JsonRpcResponse::success(req_id, params[0] + params[1]) + } + "System.GetUsers" => JsonRpcResponse::success( + req_id, + state.config_manager.get_pending_config().system.users, + ), + m => req.method_not_found(m), + }; + + Ok(response) }