diff --git a/src/api/mod.rs b/src/api/mod.rs new file mode 100644 index 0000000..87f4de3 --- /dev/null +++ b/src/api/mod.rs @@ -0,0 +1,48 @@ +mod network; +mod system; + +use crate::state::RpcState; +use jsonrpsee::{ + types::{error::ErrorCode, ErrorObject}, + RpcModule, +}; + +use custom_error::custom_error; +use tracing::info; + +custom_error! { pub ApiError + InvalidParams = "Invalid Parameters", + Leet = "1337", +} + +impl Into> for ApiError { + fn into(self) -> ErrorObject<'static> { + match self { + Self::InvalidParams => ErrorCode::InvalidParams, + Self::Leet => ErrorCode::ServerError(1337), + _ => ErrorCode::InternalError, + } + .into() + } +} + +pub fn new_rpc_module(state: RpcState) -> RpcModule { + let mut module = RpcModule::new(state); + + module + .register_method("ping", |_, _| { + info!("ping called"); + "pong" + }) + .unwrap(); + + module + .register_method("System.GetUsers", system::get_users) + .unwrap(); + + module + .register_method("Network.GetStaticRoutes", network::get_static_routes) + .unwrap(); + + module +} diff --git a/src/api/network.rs b/src/api/network.rs new file mode 100644 index 0000000..bf92118 --- /dev/null +++ b/src/api/network.rs @@ -0,0 +1,13 @@ +use jsonrpsee::types::Params; + +use crate::{definitions::network::StaticRoute, state::RpcState}; + +use super::ApiError; + +pub fn get_static_routes(_: Params, state: &RpcState) -> Result, ApiError> { + Ok(state + .config_manager + .get_pending_config() + .network + .static_routes) +} diff --git a/src/api/system.rs b/src/api/system.rs new file mode 100644 index 0000000..01ec35e --- /dev/null +++ b/src/api/system.rs @@ -0,0 +1,10 @@ +use std::collections::HashMap; + +use crate::{definitions::system::User, state::RpcState}; +use jsonrpsee::types::Params; + +use super::ApiError; + +pub fn get_users(_: Params, state: &RpcState) -> Result, ApiError> { + Ok(state.config_manager.get_pending_config().system.users) +} diff --git a/src/main.rs b/src/main.rs index cf69770..9fda6f9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,6 +5,7 @@ use std::{ sync::{Arc, RwLock}, }; +use crate::state::RpcState; use axum::{middleware, Router}; use config_manager::ConfigManager; use state::AppState; @@ -15,6 +16,7 @@ use tracing::info; use tracing_subscriber; use web::auth::SessionState; +mod api; mod config_manager; mod definitions; mod state; @@ -41,12 +43,17 @@ async fn main() { // TODO Check Config Manager Setup Error let config_manager = ConfigManager::new().unwrap(); + let session_state = SessionState { + sessions: Arc::new(RwLock::new(HashMap::new())), + }; let app_state = AppState { - config_manager, - session_state: SessionState { - sessions: Arc::new(RwLock::new(HashMap::new())), - }, + config_manager: config_manager.clone(), + session_state: session_state.clone(), + rpc_module: api::new_rpc_module(RpcState { + config_manager, + session_state, + }), }; // Note: The Router Works Bottom Up, So the auth middleware will only applies to everything above it. diff --git a/src/state.rs b/src/state.rs index 26f5551..fb63902 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,3 +1,5 @@ +use jsonrpsee::RpcModule; + use super::config_manager::ConfigManager; use super::web::auth::SessionState; @@ -5,4 +7,11 @@ use super::web::auth::SessionState; pub struct AppState { pub config_manager: ConfigManager, pub session_state: SessionState, + pub rpc_module: RpcModule, +} + +#[derive(Clone)] +pub struct RpcState { + pub config_manager: ConfigManager, + pub session_state: SessionState, } diff --git a/src/web/rpc.rs b/src/web/rpc.rs index 3b776e5..d980b98 100644 --- a/src/web/rpc.rs +++ b/src/web/rpc.rs @@ -1,42 +1,60 @@ -use std::collections::HashMap; - -use super::super::definitions::network::StaticRoute; -use super::super::definitions::system::User; -use super::super::AppState; +use crate::AppState; use axum::routing::post; use axum::{Json, Router}; -use jsonrpsee::types::Params; -use tower_cookies::{Cookie, Cookies}; +use jsonrpsee::core::traits::ToRpcParams; +use jsonrpsee::core::Error; +use serde::{Deserialize, Serialize}; +use serde_json::value::RawValue; -use axum::{ - extract::Extension, - extract::State, - http::{Request, StatusCode}, - middleware::{self, Next}, - response::{IntoResponse, Response}, -}; - -use jsonrpsee::server::{RpcModule, Server}; +use axum::{extract::Extension, extract::State, response::IntoResponse}; use tracing::info; -use custom_error::custom_error; - -custom_error! { ApiError - BadRequest = "Bad Request Parameters", +// TODO fix this "workaround" +struct ParamConverter { + params: Option>, } -struct RpcRequest<'a> { +impl ToRpcParams for ParamConverter { + fn to_rpc_params(self) -> Result>, Error> { + let s = String::from_utf8(serde_json::to_vec(&self.params)?); + match s { + Ok(s) => { + return RawValue::from_string(s) + .map(Some) + .map_err(Error::ParseError) + } + // TODO make this a Parse error wrapping Utf8Error + Err(err) => return Err(Error::AlreadyStopped), + } + } +} + +#[derive(Deserialize)] +struct RpcRequest { id: i64, - params: Params<'a>, + params: Option>, jsonrpc: String, method: String, } -struct RpcResponse<'a> { +#[derive(Clone, Deserialize, Serialize)] +struct RpcResponse { id: i64, - payload: Params<'a>, jsonrpc: String, + #[serde(skip_serializing_if = "Option::is_none")] + result: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, +} + +#[derive(Clone, Deserialize, Serialize)] + +struct RpcErrorObject { + code: i64, + message: String, + #[serde(skip_serializing_if = "Option::is_none")] + data: Option>, } pub fn routes() -> Router { @@ -46,49 +64,38 @@ pub fn routes() -> Router { async fn api_handler( State(state): State, session: Extension, - Json(rpc_request): Json>, + body: String, ) -> impl IntoResponse { info!("api hit! user: {:?}", session.username); - let module = RpcModule::new(state); - module - .register_method("say_hello", |_, _| { - println!("say_hello method called!"); - "Hello there!!" - }) - .unwrap(); - module - .register_method("System.GetUsers", get_users) - .unwrap(); - module - .register_method("Network.GetStaticRoutes", get_static_routes) - .unwrap(); + // TODO handle Parse Error + let req: RpcRequest = serde_json::from_str(&body).unwrap(); - let res = module.call(&rpc_request.method, rpc_request.params).await; + // TODO check version + + let params = ParamConverter { params: req.params }; + + // TODO check Permissions for method here? + + let res: Result>, Error> = + state.rpc_module.call(&req.method, params).await; match res { - Ok(res) => RpcResponse { - id: rpc_request.id, - jsonrpc: rpc_request.jsonrpc, - payload: res, - }, - // TODO make Error Response - Err(err) => RpcResponse { - id: rpc_request.id, - jsonrpc: rpc_request.jsonrpc, - payload: res, - }, + Ok(res) => Json(RpcResponse { + id: req.id, + jsonrpc: req.jsonrpc, + result: res, + error: None, + }), + Err(err) => Json(RpcResponse { + id: req.id, + jsonrpc: req.jsonrpc, + result: None, + error: Some(RpcErrorObject { + code: 10, + message: err.to_string(), + data: None, + }), + }), } } - -fn get_users(_: Params, state: &AppState) -> Result, String> { - Ok(state.config_manager.get_pending_config().system.users) -} - -fn get_static_routes(_: Params, state: &AppState) -> Result, String> { - Ok(state - .config_manager - .get_pending_config() - .network - .static_routes) -}