restructure project

This commit is contained in:
Samuel Lorch 2023-03-26 18:50:18 +02:00
parent dd2db438f3
commit 2ca35d4461
46 changed files with 158 additions and 84 deletions

View file

@ -0,0 +1,20 @@
package firewall
import (
"context"
"nfsense.net/nfsense/internal/definitions"
)
type GetAddressesParameters struct {
}
type GetAddressesResult struct {
Addresses map[string]definitions.Address
}
func (f *Firewall) GetAddresses(ctx context.Context, params GetForwardRulesParameters) (GetAddressesResult, error) {
return GetAddressesResult{
Addresses: f.Conf.Firewall.Addresses,
}, nil
}

View file

@ -0,0 +1,20 @@
package firewall
import (
"context"
"nfsense.net/nfsense/internal/definitions"
)
type GetDestinationNATRulesParameters struct {
}
type GetDestinationNATRulesResult struct {
DestinationNATRules []definitions.DestinationNATRule
}
func (f *Firewall) GetDestinationNATRules(ctx context.Context, params GetForwardRulesParameters) (GetDestinationNATRulesResult, error) {
return GetDestinationNATRulesResult{
DestinationNATRules: f.Conf.Firewall.DestinationNATRules,
}, nil
}

View file

@ -0,0 +1,9 @@
package firewall
import (
"nfsense.net/nfsense/internal/definitions"
)
type Firewall struct {
Conf *definitions.Config
}

View file

@ -0,0 +1,20 @@
package firewall
import (
"context"
"nfsense.net/nfsense/internal/definitions"
)
type GetForwardRulesParameters struct {
}
type GetForwardRulesResult struct {
ForwardRules []definitions.ForwardRule
}
func (f *Firewall) GetForwardRules(ctx context.Context, params GetForwardRulesParameters) (GetForwardRulesResult, error) {
return GetForwardRulesResult{
ForwardRules: f.Conf.Firewall.ForwardRules,
}, nil
}

View file

@ -0,0 +1,20 @@
package firewall
import (
"context"
"nfsense.net/nfsense/internal/definitions"
)
type GetServicesParameters struct {
}
type GetServicesResult struct {
Services map[string]definitions.Service
}
func (f *Firewall) GetServices(ctx context.Context, params GetForwardRulesParameters) (GetServicesResult, error) {
return GetServicesResult{
Services: f.Conf.Firewall.Services,
}, nil
}

View file

@ -0,0 +1,20 @@
package firewall
import (
"context"
"nfsense.net/nfsense/internal/definitions"
)
type GetSourceNATRulesParameters struct {
}
type GetSourceNATRulesResult struct {
SourceNATRules []definitions.SourceNATRule
}
func (f *Firewall) GetSourceNATRules(ctx context.Context, params GetForwardRulesParameters) (GetSourceNATRulesResult, error) {
return GetSourceNATRulesResult{
SourceNATRules: f.Conf.Firewall.SourceNATRules,
}, nil
}

View file

@ -0,0 +1,53 @@
package definitions
import (
"encoding/json"
"net/netip"
"go4.org/netipx"
)
type Address struct {
Type AddressType `json:"type" validate:"min=0,max=3"`
Comment string `json:"comment,omitempty"`
Host *netip.Addr `json:"host,omitempty" validate:"excluded_unless=Type 0"`
Range *netipx.IPRange `json:"range,omitempty" validate:"excluded_unless=Type 1"`
Network *IPNet `json:"network,omitempty" validate:"excluded_unless=Type 2"`
Children *[]string `json:"children,omitempty"`
}
type AddressType int
const (
Host AddressType = iota
Range
Network
AddressGroup
)
func (t AddressType) String() string {
return [...]string{"host", "range", "network", "group"}[t]
}
func (t *AddressType) FromString(input string) AddressType {
return map[string]AddressType{
"host": Host,
"range": Range,
"network": Network,
"group": AddressGroup,
}[input]
}
func (t AddressType) MarshalJSON() ([]byte, error) {
return json.Marshal(t.String())
}
func (t *AddressType) UnmarshalJSON(b []byte) error {
var s string
err := json.Unmarshal(b, &s)
if err != nil {
return err
}
*t = t.FromString(s)
return nil
}

View file

@ -0,0 +1,36 @@
package definitions
import (
"fmt"
"github.com/go-playground/validator/v10"
"golang.org/x/exp/slog"
)
type Config struct {
ConfigVersion uint64 `json:"config_version" validate:"required,eq=1"`
Firewall Firewall `json:"firewall" validate:"required,dive"`
}
func ValidateConfig(conf *Config) error {
val := validator.New()
val.RegisterValidation("test", nilIfOtherNil)
return val.Struct(conf)
}
func nilIfOtherNil(fl validator.FieldLevel) bool {
slog.Info("Start", "field", fl.FieldName(), "param", fl.Param())
if !fl.Field().IsNil() {
slog.Info("Field is not nil", "field", fl.FieldName())
f := fl.Parent().FieldByName(fl.Param())
if f.IsZero() {
panic(fmt.Errorf("Param %v is not a Valid Field", fl.Param()))
}
if !f.IsNil() {
slog.Info("Fail", "field", fl.FieldName(), "param", fl.Param())
return false
}
}
slog.Info("Success", "field", fl.FieldName(), "param", fl.Param())
return true
}

View file

@ -0,0 +1,7 @@
package definitions
type DestinationNATRule struct {
Rule
Address string `json:"address,omitempty"`
Service string `json:"service,omitempty"`
}

View file

@ -0,0 +1,9 @@
package definitions
type Firewall struct {
ForwardRules []ForwardRule `json:"forward_rules" validate:"required,dive"`
DestinationNATRules []DestinationNATRule `json:"destination_nat_rules" validate:"required,dive"`
SourceNATRules []SourceNATRule `json:"source_nat_rules" validate:"required,dive"`
Addresses map[string]Address `json:"addresses" validate:"required,dive"`
Services map[string]Service `json:"services" validate:"required,dive"`
}

View file

@ -0,0 +1,30 @@
package definitions
import (
"encoding/json"
"net"
)
type IPNet struct {
net.IPNet
}
// MarshalJSON for IPNet
func (i IPNet) MarshalJSON() ([]byte, error) {
return json.Marshal(i.String())
}
// UnmarshalJSON for IPNet
func (i *IPNet) UnmarshalJSON(b []byte) error {
var s string
if err := json.Unmarshal(b, &s); err != nil {
return err
}
_, ipnet, err := net.ParseCIDR(s)
if err != nil {
return err
}
i.IPNet = *ipnet
return nil
}

View file

@ -0,0 +1,8 @@
package definitions
type Match struct {
TCPDestinationPort uint64 `json:"tcp_destination_port,omitempty"`
Services []string `json:"services,omitempty"`
SourceAddresses []string `json:"source_addresses,omitempty"`
DestinationAddresses []string `json:"destination_addresses,omitempty"`
}

View file

@ -0,0 +1,50 @@
package definitions
import "encoding/json"
type Rule struct {
ID uint64 `json:"id" validate:"required,gt=0"`
Name string `json:"name" validate:"required"`
Match Match `json:"match" validate:"required,dive"`
Comment string `json:"comment,omitempty"`
Counter bool `json:"counter,omitempty"`
}
type ForwardRule struct {
Rule
Verdict Verdict `json:"verdict" validate:"min=0,max=2"`
}
type Verdict int
const (
Accept Verdict = iota
Drop
Continue
)
func (t Verdict) String() string {
return [...]string{"accept", "drop", "continue"}[t]
}
func (t *Verdict) FromString(input string) Verdict {
return map[string]Verdict{
"accept": Accept,
"drop": Drop,
"continue": Continue,
}[input]
}
func (t Verdict) MarshalJSON() ([]byte, error) {
return json.Marshal(t.String())
}
func (t *Verdict) UnmarshalJSON(b []byte) error {
var s string
err := json.Unmarshal(b, &s)
if err != nil {
return err
}
*t = t.FromString(s)
return nil
}

View file

@ -0,0 +1,71 @@
package definitions
import (
"encoding/json"
"fmt"
)
type Service struct {
Type ServiceType `json:"type" validate:"min=0,max=3"`
Comment string `json:"comment,omitempty"`
SPortStart *uint32 `json:"sport_start,omitempty" validate:"excluded_unless=Type 0|excluded_unless=Type 1"`
SPortEnd *uint32 `json:"sport_end,omitempty"`
DPortStart *uint32 `json:"dport_start,omitempty" validate:"excluded_unless=Type 0|excluded_unless=Type 1"`
DPortEnd *uint32 `json:"dport_end,omitempty"`
ICMPCode *uint32 `json:"icmp_code,omitempty" validate:"excluded_unless=Type 2"`
Children *[]string `json:"children,omitempty"`
}
func (s Service) GetSPort() string {
if s.SPortStart == nil || *s.SPortStart == 0 {
return ""
} else if s.SPortEnd == nil || *s.SPortEnd == 0 {
return fmt.Sprintf("%d", *s.SPortStart)
}
return fmt.Sprintf("%d - %d", *s.SPortStart, *s.SPortEnd)
}
func (s Service) GetDPort() string {
if s.DPortStart == nil || *s.DPortStart == 0 {
return ""
} else if s.DPortEnd == nil || *s.DPortEnd == 0 {
return fmt.Sprintf("%d", *s.DPortStart)
}
return fmt.Sprintf("%d - %d", *s.DPortStart, *s.DPortEnd)
}
type ServiceType int
const (
TCP ServiceType = iota
UDP
ICMP
ServiceGroup
)
func (t ServiceType) String() string {
return [...]string{"tcp", "udp", "icmp", "group"}[t]
}
func (t *ServiceType) FromString(input string) ServiceType {
return map[string]ServiceType{
"tcp": TCP,
"udp": UDP,
"icmp": ICMP,
"group": ServiceGroup,
}[input]
}
func (t ServiceType) MarshalJSON() ([]byte, error) {
return json.Marshal(t.String())
}
func (t *ServiceType) UnmarshalJSON(b []byte) error {
var s string
err := json.Unmarshal(b, &s)
if err != nil {
return err
}
*t = t.FromString(s)
return nil
}

View file

@ -0,0 +1,42 @@
package definitions
import "encoding/json"
type SourceNATRule struct {
Rule
Type SnatType `json:"type" validate:"min=0,max=1"`
Address string `json:"address,omitempty"`
Service string `json:"service,omitempty"`
}
type SnatType int
const (
Snat SnatType = iota
Masquerade
)
func (t SnatType) String() string {
return [...]string{"snat", "masquerade"}[t]
}
func (t *SnatType) FromString(input string) SnatType {
return map[string]SnatType{
"snat": Snat,
"masquerade": Masquerade,
}[input]
}
func (t SnatType) MarshalJSON() ([]byte, error) {
return json.Marshal(t.String())
}
func (t *SnatType) UnmarshalJSON(b []byte) error {
var s string
err := json.Unmarshal(b, &s)
if err != nil {
return err
}
*t = t.FromString(s)
return nil
}

37
internal/jsonrpc/error.go Normal file
View file

@ -0,0 +1,37 @@
package jsonrpc
import (
"io"
)
type ErrorCode int
const (
ErrParse ErrorCode = -32700
ErrInvalidRequest ErrorCode = -32600
ErrMethodNotFound ErrorCode = -32601
ErrInvalidParams ErrorCode = -32602
ErrInternalError ErrorCode = -32603
// Custom
ErrRequestError ErrorCode = -32000
)
type respError struct {
Code ErrorCode `json:"code"`
Message string `json:"message"`
// cannot be omitempty because of frontend library
Data any `json:"data"`
}
func respondError(w io.Writer, id any, code ErrorCode, err error) error {
respond(w, response{
Jsonrpc: "2.0",
ID: id,
Error: &respError{
Code: code,
Message: err.Error(),
},
})
return err
}

View file

@ -0,0 +1,98 @@
package jsonrpc
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"reflect"
"runtime/debug"
"golang.org/x/exp/slog"
"nfsense.net/nfsense/internal/session"
)
type Handler struct {
methods map[string]method
maxRequestSize int64
}
func NewHandler(maxRequestSize int64) *Handler {
return &Handler{
methods: map[string]method{},
maxRequestSize: maxRequestSize,
}
}
func (h *Handler) HandleRequest(ctx context.Context, s *session.Session, r io.Reader, w io.Writer) error {
defer func() {
if r := recover(); r != nil {
slog.Error("Recovered Panic Handling JSONRPC Request", fmt.Errorf("%v", r), "stack", debug.Stack())
}
}()
var req request
bufferedRequest := new(bytes.Buffer)
reqSize, err := bufferedRequest.ReadFrom(io.LimitReader(r, h.maxRequestSize+1))
if err != nil {
return respondError(w, "", ErrInternalError, fmt.Errorf("Reading Request: %w", err))
}
if reqSize > h.maxRequestSize {
return respondError(w, "", ErrParse, fmt.Errorf("Request exceeds Max Request Size"))
}
dec := json.NewDecoder(bufferedRequest)
dec.DisallowUnknownFields()
err = dec.Decode(&req)
if err != nil {
return respondError(w, "", ErrParse, fmt.Errorf("Decodeing Request: %w", err))
}
if req.Jsonrpc != "2.0" {
return respondError(w, req.ID, ErrMethodNotFound, fmt.Errorf("Unsupported Jsonrpc version %v", req.Jsonrpc))
}
if s == nil {
return respondError(w, req.ID, 401, fmt.Errorf("Unauthorized"))
}
method, ok := h.methods[req.Method]
if !ok {
return respondError(w, req.ID, ErrMethodNotFound, fmt.Errorf("Unknown Method %v", req.Method))
}
p := reflect.New(method.inType)
paramPointer := p.Interface()
if len(req.Params) != 0 {
dec = json.NewDecoder(bytes.NewReader(req.Params))
dec.DisallowUnknownFields()
err = dec.Decode(paramPointer)
if err != nil {
return respondError(w, req.ID, ErrInvalidParams, fmt.Errorf("Decoding Parameters: %w", err))
}
}
params := make([]reflect.Value, 3)
params[0] = method.subSystem
params[1] = reflect.ValueOf(ctx)
params[2] = reflect.ValueOf(paramPointer).Elem()
defer func() {
if r := recover(); r != nil {
slog.Error("Recovered Panic Executing API Method", fmt.Errorf("%v", r), "method", req.Method, "id", req.ID, "stack", debug.Stack())
}
}()
res := method.handlerFunc.Call(params)
result := res[0].Interface()
if !res[1].IsNil() {
reqerr := res[1].Interface().(error)
slog.Error("API Method", reqerr, "method", req.Method, "id", req.ID)
respondError(w, req.ID, ErrInternalError, reqerr)
}
respondResult(w, req.ID, result)
return nil
}

View file

@ -0,0 +1,10 @@
package jsonrpc
import "reflect"
type method struct {
subSystem reflect.Value
handlerFunc reflect.Value
inType reflect.Type
outType reflect.Type
}

View file

@ -0,0 +1,46 @@
package jsonrpc
import (
"context"
"fmt"
"reflect"
)
func (h *Handler) Register(subSystemName string, s any) {
subSystem := reflect.ValueOf(s)
for i := 0; i < subSystem.NumMethod(); i++ {
m := subSystem.Type().Method(i)
funcType := m.Func.Type()
if funcType.NumIn() != 3 {
panic(fmt.Errorf("2 parameters are required %v", funcType.NumIn()))
}
if funcType.In(1) != reflect.TypeOf(new(context.Context)).Elem() {
panic(fmt.Errorf("the first argument needs to be a context.Context instead of %v ", funcType.In(1)))
}
if funcType.In(2).Kind() != reflect.Struct {
panic("the second argument needs to be a struct")
}
if funcType.NumOut() != 2 {
panic("2 return types are required")
}
if reflect.TypeOf(new(error)).Implements(funcType.Out(1)) {
panic("the second return type needs to be a error")
}
name := m.Name
if subSystemName != "" {
name = subSystemName + "." + name
}
h.methods[name] = method{
handlerFunc: m.Func,
subSystem: subSystem,
inType: funcType.In(2),
outType: funcType.Out(0),
}
}
}

View file

@ -0,0 +1,10 @@
package jsonrpc
import "encoding/json"
type request struct {
Jsonrpc string `json:"jsonrpc"`
ID any `json:"id,omitempty"`
Method string `json:"method"`
Params json.RawMessage `json:"params"`
}

View file

@ -0,0 +1,30 @@
package jsonrpc
import (
"encoding/json"
"io"
"golang.org/x/exp/slog"
)
type response struct {
Jsonrpc string `json:"jsonrpc"`
Result any `json:"result,omitempty"`
ID any `json:"id"`
Error *respError `json:"error,omitempty"`
}
func respond(w io.Writer, resp response) {
err := json.NewEncoder(w).Encode(resp)
if err != nil {
slog.Warn("write response", "err", err)
}
}
func respondResult(w io.Writer, id, res any) {
respond(w, response{
Jsonrpc: "2.0",
ID: id,
Result: res,
})
}

View file

@ -0,0 +1,36 @@
package nftables
import (
"bytes"
"fmt"
"os"
"nfsense.net/nfsense/internal/definitions"
)
func GenerateNfTablesFile(conf definitions.Config) (string, error) {
buf := new(bytes.Buffer)
err := templates.ExecuteTemplate(buf, "nftables.tmpl", conf)
if err != nil {
return "", fmt.Errorf("executing template: %w", err)
}
return buf.String(), nil
}
func ApplyNfTablesFile(content string) error {
f, err := os.Create("nftables.conf")
if err != nil {
return fmt.Errorf("creating File: %w", err)
}
_, err = f.WriteString(content + "\n")
if err != nil {
return fmt.Errorf("writing File: %w", err)
}
err = f.Sync()
if err != nil {
return fmt.Errorf("syncing File: %w", err)
}
return nil
}

110
internal/nftables/match.go Normal file
View file

@ -0,0 +1,110 @@
package nftables
import (
"fmt"
"nfsense.net/nfsense/internal/definitions"
"nfsense.net/nfsense/internal/util"
)
func GenerateMatcher(services map[string]definitions.Service, addresses map[string]definitions.Address, match definitions.Match) (string, error) {
return GenerateAddressMatcher(addresses, match) + " " + GenerateServiceMatcher(services, match), nil
}
func GenerateServiceMatcher(allServices map[string]definitions.Service, match definitions.Match) string {
serviceList := util.ResolveBaseServices(allServices, match.Services)
tcpSPorts := []string{}
tcpDPorts := []string{}
udpSPorts := []string{}
udpDPorts := []string{}
icmpCodes := []string{}
for _, service := range serviceList {
switch service.Type {
case definitions.TCP:
if service.GetSPort() != "" {
tcpSPorts = append(tcpSPorts, service.GetSPort())
}
if service.GetDPort() != "" {
tcpDPorts = append(tcpDPorts, service.GetDPort())
}
case definitions.UDP:
if service.GetSPort() != "" {
udpSPorts = append(udpSPorts, service.GetSPort())
}
if service.GetDPort() != "" {
udpDPorts = append(udpDPorts, service.GetDPort())
}
case definitions.ICMP:
icmpCodes = append(icmpCodes, fmt.Sprint(service.ICMPCode))
default:
panic("invalid service type")
}
}
res := ""
if len(tcpSPorts) != 0 {
res += "tcp sport " + util.ConvertSliceToSetString(tcpSPorts) + " "
}
if len(tcpDPorts) != 0 {
res += "tcp dport " + util.ConvertSliceToSetString(tcpDPorts) + " "
}
if len(udpSPorts) != 0 {
res += "udp sport " + util.ConvertSliceToSetString(udpSPorts) + " "
}
if len(udpDPorts) != 0 {
res += "udp dport " + util.ConvertSliceToSetString(udpDPorts) + " "
}
if len(icmpCodes) != 0 {
res += "icmp codes " + util.ConvertSliceToSetString(icmpCodes) + " "
}
return res
}
func GenerateAddressMatcher(allAddresses map[string]definitions.Address, match definitions.Match) string {
sourceAddressList := util.ResolveBaseAddresses(allAddresses, match.SourceAddresses)
destinationAddressList := util.ResolveBaseAddresses(allAddresses, match.DestinationAddresses)
sourceAddresses := []string{}
destinationAddresses := []string{}
for _, address := range sourceAddressList {
switch address.Type {
case definitions.Host:
sourceAddresses = append(sourceAddresses, address.Host.String())
case definitions.Range:
sourceAddresses = append(sourceAddresses, address.Range.String())
case definitions.Network:
sourceAddresses = append(sourceAddresses, address.Network.String())
default:
panic("invalid address type")
}
}
for _, address := range destinationAddressList {
switch address.Type {
case definitions.Host:
destinationAddresses = append(destinationAddresses, address.Host.String())
case definitions.Range:
destinationAddresses = append(destinationAddresses, address.Range.String())
case definitions.Network:
destinationAddresses = append(destinationAddresses, address.Network.String())
default:
panic("invalid address type")
}
}
res := ""
if len(sourceAddresses) != 0 {
res += "ip saddr " + util.ConvertSliceToSetString(sourceAddresses) + " "
}
if len(destinationAddresses) != 0 {
res += "ip daddr " + util.ConvertSliceToSetString(destinationAddresses) + " "
}
return res
}

View file

@ -0,0 +1,24 @@
package nftables
import (
"embed"
"text/template"
)
//go:embed template
var templateFS embed.FS
var templates *template.Template
func init() {
funcMap := template.FuncMap{
// The name "title" is what the function will be called in the template text.
"matcher": GenerateMatcher,
}
var err error
templates, err = template.New("").Funcs(funcMap).ParseFS(templateFS, "template/*.tmpl")
if err != nil {
panic(err)
}
}

View file

@ -0,0 +1,2 @@
{{ range $rule := .Firewall.DestinationNATRules }}
{{ matcher $.Firewall.Services $.Firewall.Addresses $rule.Match }}{{ if $rule.Counter }} counter{{ end }}{{ if ne $rule.Comment "" }} comment "{{ $rule.Comment }}"{{ end }}{{ end }}

View file

@ -0,0 +1,2 @@
{{range $rule := .Firewall.ForwardRules}}
{{ matcher $.Firewall.Services $.Firewall.Addresses $rule.Match }}{{ if $rule.Counter }} counter{{ end }} {{ $rule.Verdict.String }}{{ if ne $rule.Comment "" }} comment "{{ $rule.Comment }}"{{ end }}{{ end }}

View file

@ -0,0 +1,44 @@
#!/usr/sbin/nft -f
flush ruleset
# Address object ipsets
{{template "addresses.tmpl" .}}
# nfsense nftables inet (ipv4 + ipv6) table
table inet nfsense_inet {
# Inbound Rules
chain inbound {
type filter hook input priority 0; policy drop;
# Allow traffic from established and related packets, drop invalid
ct state vmap { established : accept, related : accept, invalid : drop }
# allow loopback traffic
iifname lo accept
{{template "inbound_rules.tmpl" .}}
}
# Forward Rules
chain forward {
type filter hook forward priority 0; policy drop;
# Allow traffic from established and related packets, drop invalid
ct state vmap { established : accept, related : accept, invalid : drop }
{{template "forward_rules.tmpl" .}}
}
# Destination NAT Rules
chain prerouting {
type nat hook prerouting priority -100; policy accept;
{{template "destination_nat_rules.tmpl" .}}
}
# Source NAT Rules
chain postrouting {
type nat hook postrouting priority 100; policy accept;
{{template "source_nat_rules.tmpl" .}}
}
}

View file

@ -0,0 +1,2 @@
{{ range $rule := .Firewall.SourceNATRules }}
{{ matcher $.Firewall.Services $.Firewall.Addresses $rule.Match }}{{ if $rule.Counter }} counter{{ end }}{{ if ne $rule.Comment "" }} comment "{{ $rule.Comment }}"{{ end }}{{ end }}

36
internal/server/api.go Normal file
View file

@ -0,0 +1,36 @@
package server
import (
"context"
"fmt"
"net/http"
"runtime/debug"
"time"
"golang.org/x/exp/slog"
"nfsense.net/nfsense/internal/session"
)
func HandleAPI(w http.ResponseWriter, r *http.Request) {
slog.Info("Api Handler hit")
_, s := session.GetSession(r)
if s == nil {
// Fallthrough after so that jsonrpc can still deliver a valid jsonrpc error
w.WriteHeader(http.StatusUnauthorized)
}
defer func() {
if r := recover(); r != nil {
slog.Error("Recovered Panic Handling HTTP API Request", fmt.Errorf("%v", r), "stack", debug.Stack())
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
return
}
}()
ctx, cancel := context.WithTimeout(context.WithValue(r.Context(), session.SessionKey, s), time.Second*10)
defer cancel()
err := apiHandler.HandleRequest(ctx, s, r.Body, w)
if err != nil {
slog.Error("Handling HTTP API Request", err)
}
}

53
internal/server/server.go Normal file
View file

@ -0,0 +1,53 @@
package server
import (
"context"
"errors"
"fmt"
"net/http"
"golang.org/x/exp/slog"
"nfsense.net/nfsense/internal/definitions"
"nfsense.net/nfsense/internal/jsonrpc"
"nfsense.net/nfsense/internal/session"
)
var server http.Server
var mux = http.NewServeMux()
var apiHandler *jsonrpc.Handler
var stopCleanup chan struct{}
func StartWebserver(conf *definitions.Config, _apiHandler *jsonrpc.Handler) {
server.Addr = ":8080"
server.Handler = mux
apiHandler = _apiHandler
// Routing
mux.HandleFunc("/login", HandleLogin)
mux.HandleFunc("/logout", HandleLogout)
mux.HandleFunc("/session", HandleSession)
mux.HandleFunc("/api", HandleAPI)
mux.HandleFunc("/ws/api", HandleWebsocketAPI)
mux.HandleFunc("/", HandleWebinterface)
stopCleanup = make(chan struct{})
go session.CleanupSessions(stopCleanup)
go func() {
if err := server.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
slog.Error("Webserver error", err)
}
slog.Info("Webserver Stopped")
}()
}
func ShutdownWebserver(ctx context.Context) error {
stopCleanup <- struct{}{}
err := server.Shutdown(ctx)
if err != nil {
return fmt.Errorf("Shutting down: %w", err)
}
return nil
}

View file

@ -0,0 +1,63 @@
package server
import (
"encoding/json"
"io"
"net/http"
"time"
"golang.org/x/exp/slog"
"nfsense.net/nfsense/internal/session"
)
type LoginRequest struct {
Username string `json:"username"`
Password string `json:"password"`
}
func HandleLogin(w http.ResponseWriter, r *http.Request) {
buf, err := io.ReadAll(r.Body)
if err != nil {
slog.Error("Reading Body", err)
return
}
var req LoginRequest
err = json.Unmarshal(buf, &req)
if err != nil {
slog.Error("Unmarshal", err)
return
}
if req.Username == "admin" && req.Password == "12345" {
slog.Info("User Login Successfull")
session.GenerateSession(w, req.Username)
w.WriteHeader(http.StatusOK)
return
}
w.WriteHeader(http.StatusUnauthorized)
}
func HandleLogout(w http.ResponseWriter, r *http.Request) {
http.SetCookie(w, session.GetCookie("", time.Now()))
w.WriteHeader(http.StatusOK)
}
func HandleSession(w http.ResponseWriter, r *http.Request) {
id, s := session.GetSession(r)
if s == nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
session.ExtendSession(s)
http.SetCookie(w, session.GetCookie(id, s.Expires))
w.WriteHeader(http.StatusOK)
resp := session.SessionResponse{
CommitHash: session.CommitHash,
}
res, err := json.Marshal(resp)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Write(res)
}

View file

@ -0,0 +1,9 @@
package server
import (
"net/http"
)
func HandleWebinterface(w http.ResponseWriter, r *http.Request) {
}

View file

@ -0,0 +1,61 @@
package server
import (
"bytes"
"context"
"fmt"
"net/http"
"runtime/debug"
"time"
"golang.org/x/exp/slog"
"nfsense.net/nfsense/internal/session"
"nhooyr.io/websocket"
)
func HandleWebsocketAPI(w http.ResponseWriter, r *http.Request) {
_, s := session.GetSession(r)
if s == nil {
w.WriteHeader(http.StatusUnauthorized)
return
}
ctx, cancel := context.WithCancel(context.WithValue(r.Context(), session.SessionKey, s))
defer cancel()
c, err := websocket.Accept(w, r, nil)
if err != nil {
slog.Error("Accepting Websocket Connection", err)
return
}
defer c.Close(websocket.StatusInternalError, "Unexpected Closing")
slog.Info("Accepted API Websocket Connection")
for {
_, m, err := c.Read(ctx)
if websocket.CloseStatus(err) == websocket.StatusNormalClosure {
slog.Info("API Websocket Closed Normally")
cancel()
return
} else if err != nil {
slog.Error("API Websocket Closed Unexpectedly", err)
cancel()
}
go func() {
defer func() {
if r := recover(); r != nil {
slog.Error("Recovered Panic Handling Websocket API Request", fmt.Errorf("%v", r), "stack", debug.Stack())
return
}
}()
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
err := apiHandler.HandleRequest(ctx, s, bytes.NewReader(m), w)
if err != nil {
slog.Error("Handling Websocket API Request", err)
}
}()
}
}

View file

@ -0,0 +1,10 @@
package session
import (
"net/http"
"time"
)
func GetCookie(value string, expires time.Time) *http.Cookie {
return &http.Cookie{Name: SessionCookieName, HttpOnly: true, SameSite: http.SameSiteStrictMode, Value: value, Expires: expires}
}

View file

@ -0,0 +1,93 @@
package session
import (
"net/http"
"runtime/debug"
"sync"
"time"
"github.com/google/uuid"
)
type SessionKeyType string
const SessionKey SessionKeyType = "session"
const SessionCookieName string = "session"
type Session struct {
Username string
Expires time.Time
// TODO Add []websocket.Conn pointer to close all active websockets, alternativly do this via context cancelation
}
type SessionResponse struct {
CommitHash string `json:"commit_hash"`
}
var sessionsSync sync.Mutex
var sessions map[string]*Session = map[string]*Session{}
var CommitHash = func() string {
if info, ok := debug.ReadBuildInfo(); ok {
for _, setting := range info.Settings {
if setting.Key == "vcs.revision" {
return setting.Value
}
}
}
return "asd"
}()
func ExtendSession(s *Session) {
sessionsSync.Lock()
defer sessionsSync.Unlock()
if s != nil {
s.Expires = time.Now().Add(time.Minute * 5)
}
}
func GetSession(r *http.Request) (string, *Session) {
c, err := r.Cookie("session")
if err != nil {
return "", nil
}
s, ok := sessions[c.Value]
if ok {
return c.Value, s
}
return "", nil
}
func GenerateSession(w http.ResponseWriter, username string) {
id := uuid.New().String()
expires := time.Now().Add(time.Minute * 5)
sessionsSync.Lock()
defer sessionsSync.Unlock()
sessions[id] = &Session{
Username: username,
Expires: expires,
}
http.SetCookie(w, &http.Cookie{Name: SessionCookieName, HttpOnly: true, SameSite: http.SameSiteStrictMode, Value: id, Expires: expires})
}
func CleanupSessions(stop chan struct{}) {
tick := time.NewTicker(time.Minute)
for {
select {
case <-tick.C:
ids := []string{}
sessionsSync.Lock()
for id, s := range sessions {
if time.Now().After(s.Expires) {
ids = append(ids, id)
}
}
for _, id := range ids {
delete(sessions, id)
}
sessionsSync.Unlock()
case <-stop:
return
}
}
}

35
internal/util/address.go Normal file
View file

@ -0,0 +1,35 @@
package util
import "nfsense.net/nfsense/internal/definitions"
// ResolveBaseAddresses Resolves all groups to their base Addresses
func ResolveBaseAddresses(allAddresses map[string]definitions.Address, addressNames []string) []definitions.Address {
baseAddresses := []definitions.Address{}
for _, addressName := range addressNames {
address := allAddresses[addressName]
if address.Type == definitions.AddressGroup {
baseAddresses = append(baseAddresses, resolveAddressChildren(allAddresses, address)...)
} else {
baseAddresses = append(baseAddresses, address)
}
}
return baseAddresses
}
func resolveAddressChildren(allAddresses map[string]definitions.Address, a definitions.Address) []definitions.Address {
addressList := []definitions.Address{}
for _, addressName := range *a.Children {
address := allAddresses[addressName]
if address.Type == definitions.AddressGroup {
addressList = append(addressList, resolveAddressChildren(allAddresses, address)...)
} else {
addressList = append(addressList, address)
}
}
return addressList
}

35
internal/util/service.go Normal file
View file

@ -0,0 +1,35 @@
package util
import "nfsense.net/nfsense/internal/definitions"
// ResolveBaseServices Resolves all groups to their base Services
func ResolveBaseServices(allServices map[string]definitions.Service, serviceNames []string) []definitions.Service {
baseServices := []definitions.Service{}
for _, serviceName := range serviceNames {
service := allServices[serviceName]
if service.Type == definitions.ServiceGroup {
baseServices = append(baseServices, resolveServiceChildren(allServices, service)...)
} else {
baseServices = append(baseServices, service)
}
}
return baseServices
}
func resolveServiceChildren(allServices map[string]definitions.Service, s definitions.Service) []definitions.Service {
serviceList := []definitions.Service{}
for _, serviceName := range *s.Children {
service := allServices[serviceName]
if service.Type == definitions.ServiceGroup {
serviceList = append(serviceList, resolveServiceChildren(allServices, service)...)
} else {
serviceList = append(serviceList, service)
}
}
return serviceList
}

21
internal/util/set.go Normal file
View file

@ -0,0 +1,21 @@
package util
func ConvertSliceToSetString(slice []string) string {
if len(slice) == 0 {
return ""
} else if len(slice) == 1 {
return slice[0]
}
res := "{ "
for i := range slice {
res += " " + slice[i]
if i < len(slice)-1 {
res += ","
}
}
res += " }"
return res
}