diff --git a/rpc/call.go b/rpc/call.go new file mode 100644 index 0000000..f7ad220 --- /dev/null +++ b/rpc/call.go @@ -0,0 +1,82 @@ +package rpc + +import ( + "context" + "encoding/json" + "fmt" + + "github.com/google/uuid" + "nhooyr.io/websocket" +) + +func (s *server) Call(ctx context.Context, c *websocket.Conn, method string, params, result any) (*Response, error) { + id := uuid.New().String() + resp := make(chan *Response, 1) + + var dataParams []byte + var err error + if params != nil { + dataParams, err = json.Marshal(params) + if err != nil { + return nil, fmt.Errorf("Error Marshalling Params: %w", err) + } + } + + rawParams := json.RawMessage(dataParams) + + req := Request{ + ID: id, + Method: method, + Params: &rawParams, + } + + reqData, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("Error Marshalling Request: %w", err) + } + + // Add Call to Request Map + func() { + s.requestMutex.Lock() + defer s.requestMutex.Unlock() + + s.requests[id] = resp + }() + + // Remove Call from Request map + defer func() { + s.requestMutex.Lock() + defer s.requestMutex.Unlock() + + delete(s.requests, id) + }() + + // Write Request + err = c.Write(ctx, websocket.MessageText, reqData) + if err != nil { + return nil, fmt.Errorf("Error Writing Request: %w", err) + } + + // Wait for Response, TODO add Select Timeout + response := <-resp + + if response.Error != nil { + return response, fmt.Errorf("Call Error: %w", err) + } + + if result == nil { + return response, nil + } + + if response.Result == nil { + return response, fmt.Errorf("Got Empty Result") + } + + err = json.Unmarshal(*response.Result, &result) + if err != nil { + return response, fmt.Errorf("Error Parsing Result: %w", err) + } + return response, nil +} + +// TODO Call with Multiple Response (Chunked file upload) diff --git a/rpc/error.go b/rpc/error.go new file mode 100644 index 0000000..7f46231 --- /dev/null +++ b/rpc/error.go @@ -0,0 +1,41 @@ +package rpc + +import ( + "context" + "encoding/json" + "log/slog" + + "nhooyr.io/websocket" +) + +func respondError(ctx context.Context, c *websocket.Conn, id string, code int64, err error, data any) { + slog.ErrorContext(ctx, "Responding to Websocket Request With Error", "err", err, "id", id, "code", code, "data", data) + rData := []byte{} + if data != nil { + rData, err = json.Marshal(data) + if err != nil { + slog.ErrorContext(ctx, "Error Marshalling Error Data", "err", err) + return + } + } + + raw := json.RawMessage(rData) + + resp, err := json.Marshal(Response{ + ID: id, + Error: &Error{ + Code: code, + Message: err.Error(), + Data: &raw, + }, + }) + if err != nil { + slog.ErrorContext(ctx, "Error Marshalling Error Response", "err", err) + return + } + err = c.Write(ctx, websocket.MessageText, resp) + if err != nil { + slog.ErrorContext(ctx, "Error Sending Error Response", "err", err) + return + } +} diff --git a/rpc/request.go b/rpc/request.go new file mode 100644 index 0000000..bfa1aab --- /dev/null +++ b/rpc/request.go @@ -0,0 +1,61 @@ +package rpc + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + + "nhooyr.io/websocket" +) + +func (s *server) handleRequest(ctx context.Context, c *websocket.Conn, data []byte) { + var request Request + err := json.Unmarshal(data, &request) + if err != nil { + respondError(ctx, c, "", ERROR_JRPC2_PARSE_ERROR, fmt.Errorf("Error Parsing Request: %w", err), nil) + return + } + + // Get the Requested function + fun, ok := s.methods[request.Method] + if !ok { + respondError(ctx, c, request.ID, ERROR_JRPC2_METHOD_NOT_FOUND, fmt.Errorf("Method Not Found"), nil) + return + } + + reqCtx, cancel := context.WithCancel(ctx) + defer cancel() + + // Run the Requested function + result, err := fun(reqCtx, request) + if err != nil { + respondError(ctx, c, request.ID, 1000, fmt.Errorf("Method Error: %w", err), result) + return + } + + rData := json.RawMessage{} + if data != nil { + rData, err = json.Marshal(result) + if err != nil { + respondError(ctx, c, request.ID, ERROR_JRPC2_INTERNAL, fmt.Errorf("Error Marshalling Response Data: %w", err), nil) + return + } + } + slog.InfoContext(ctx, "response data", "rdata", rData) + + resp, err := json.Marshal(Response{ + ID: request.ID, + Result: &rData, + }) + if err != nil { + respondError(ctx, c, request.ID, ERROR_JRPC2_INTERNAL, fmt.Errorf("Error Marshalling Response: %w", err), nil) + return + } + + err = c.Write(ctx, websocket.MessageText, resp) + if err != nil { + respondError(ctx, c, request.ID, ERROR_JRPC2_INTERNAL, fmt.Errorf("Error Sending Response: %w", err), nil) + return + } +} diff --git a/rpc/response.go b/rpc/response.go new file mode 100644 index 0000000..52d2239 --- /dev/null +++ b/rpc/response.go @@ -0,0 +1,27 @@ +package rpc + +import ( + "context" + "encoding/json" + "log/slog" +) + +func (s *server) handleResponse(ctx context.Context, data []byte) { + var response Response + err := json.Unmarshal(data, &response) + if err != nil { + slog.ErrorContext(ctx, "Cannot Parse Response", "err", err) + return + } + + s.requestMutex.Lock() + defer s.requestMutex.Unlock() + r, ok := s.requests[response.ID] + if !ok { + slog.ErrorContext(ctx, "Unknown Response", "response", response) + return + } + + // Send Response to Original Caller + r <- &response +} diff --git a/rpc/server.go b/rpc/server.go new file mode 100644 index 0000000..0f2ef1a --- /dev/null +++ b/rpc/server.go @@ -0,0 +1,42 @@ +package rpc + +import ( + "context" + "encoding/json" + "fmt" + + "nhooyr.io/websocket" +) + +const ERROR_JRPC2_PARSE_ERROR = -32700 +const ERROR_JRPC2_METHOD_NOT_FOUND = -32601 +const ERROR_JRPC2_INTERNAL = -32603 + +func NewServer() *server { + return &server{ + methods: make(map[string]func(context.Context, Request) (any, error)), + requests: make(map[string]chan *Response, 1), + } +} + +func (s *server) RegisterMethod(name string, method func(context.Context, Request) (any, error)) { + s.methods[name] = method +} + +// TODO Method With Multiple Responses + +func (s *server) HandleMessage(ctx context.Context, c *websocket.Conn, data []byte) { + var message Message + err := json.Unmarshal(data, &message) + if err != nil { + respondError(ctx, c, "", ERROR_JRPC2_PARSE_ERROR, fmt.Errorf("Error Parsing Message: %w", err), nil) + return + } + + // Check if this is a Request or a Response with the Existance of the Method field + if message.Method != nil { + s.handleRequest(ctx, c, data) + } else { + s.handleResponse(ctx, data) + } +} diff --git a/rpc/types.go b/rpc/types.go new file mode 100644 index 0000000..856c72c --- /dev/null +++ b/rpc/types.go @@ -0,0 +1,36 @@ +package rpc + +import ( + "context" + "encoding/json" + "sync" +) + +type server struct { + methods map[string]func(context.Context, Request) (any, error) + requestMutex sync.Mutex + requests map[string]chan *Response +} + +type Message struct { + ID string `json:"id"` + Method *string `json:"method,omitempty"` +} + +type Request struct { + Method string `json:"method"` + Params *json.RawMessage `json:"params,omitempty"` + ID string `json:"id"` +} + +type Response struct { + ID string `json:"id"` + Result *json.RawMessage `json:"result,omitempty"` + Error *Error `json:"error,omitempty"` +} + +type Error struct { + Code int64 `json:"code"` + Message string `json:"message"` + Data *json.RawMessage `json:"data,omitempty"` +}