Added flags for passing client cert and client private key as file contents instead of paths

This commit is contained in:
Daniel Del Rio Figueira 2024-12-04 00:39:11 +01:00
parent 0273cee2ba
commit 78ed21f62b
No known key found for this signature in database
GPG key ID: 16C55CB50D1B770D
2 changed files with 34 additions and 13 deletions

View file

@ -2,7 +2,6 @@ package cmd
import (
"fmt"
"io/ioutil"
"os"
"path/filepath"
"time"
@ -61,8 +60,10 @@ func init() {
rootCmd.PersistentFlags().Duration("mfaDelay", time.Second*10, "Delay between MFA Attempts, only used in noninteractive modes")
rootCmd.PersistentFlags().Bool("tlsSkipVerify", false, "Allow servers with self-signed certificates")
rootCmd.PersistentFlags().String("tlsClientPrivateKeyFile", "", "Client private key for mtls")
rootCmd.PersistentFlags().String("tlsClientCertFile", "", "Client certificate for mtls")
rootCmd.PersistentFlags().String("tlsClientPrivateKeyFile", "", "Client private key path for mtls")
rootCmd.PersistentFlags().String("tlsClientCertFile", "", "Client certificate path for mtls")
rootCmd.PersistentFlags().String("tlsClientPrivateKey", "", "Client private key for mtls")
rootCmd.PersistentFlags().String("tlsClientCert", "", "Client certificate for mtls")
viper.BindPFlag("debug", rootCmd.PersistentFlags().Lookup("debug"))
viper.BindPFlag("timeout", rootCmd.PersistentFlags().Lookup("timeout"))
@ -82,6 +83,18 @@ func init() {
viper.BindPFlag("tlsClientPrivateKey", rootCmd.PersistentFlags().Lookup("tlsClientPrivateKey"))
}
func fileToContent(file, contentFlag string) {
if viper.GetBool("debug") {
fmt.Fprintln(os.Stderr, "Loading file:", file)
}
content, err := os.ReadFile(file)
if err != nil {
fmt.Fprintln(os.Stderr, "Error Loading File: ", err)
os.Exit(1)
}
viper.Set(contentFlag, string(content))
}
// initConfig reads in config file and ENV variables if set.
func initConfig() {
if cfgFile != "" {
@ -115,18 +128,26 @@ func initConfig() {
// Read in Private Key from File if userprivatekeyfile is set
userprivatekeyfile, err := rootCmd.PersistentFlags().GetString("userPrivateKeyFile")
if err == nil && userprivatekeyfile != "" {
if viper.GetBool("debug") {
fmt.Fprintln(os.Stderr, "Loading Private Key from File:", userprivatekeyfile)
}
content, err := ioutil.ReadFile(userprivatekeyfile)
if err != nil {
fmt.Fprintln(os.Stderr, "Error Loading Private Key from File: ", err)
os.Exit(1)
}
viper.Set("userprivatekey", string(content))
fileToContent(userprivatekeyfile, "userPrivateKey")
} else if err != nil && viper.GetBool("debug") {
fmt.Fprintln(os.Stderr, "Getting Private Key File Flag:", err)
}
// Read in Client Certificate Private Key from File if tlsClientPrivateKeyFile is set
tlsclientprivatekeyfile, err := rootCmd.PersistentFlags().GetString("tlsClientPrivateKeyFile")
if err == nil && tlsclientprivatekeyfile != "" {
fileToContent(tlsclientprivatekeyfile, "tlsClientPrivateKey")
} else if err != nil && viper.GetBool("debug") {
fmt.Fprintln(os.Stderr, "Getting Client Certificate Private key File Flag:", err)
}
// Read in Client Certificate from File if tlsClientCertFile is set
tlsclientcertfile, err := rootCmd.PersistentFlags().GetString("tlsClientCertFile")
if err == nil && tlsclientcertfile != "" {
fileToContent(tlsclientcertfile, "tlsClientCert")
} else if err != nil && viper.GetBool("debug") {
fmt.Fprintln(os.Stderr, "Getting Client Certificate File Flag:", err)
}
}
func SetVersionInfo(version, commit, date string, dirty bool) {

View file

@ -22,7 +22,7 @@ func GetClientCertificate() (tls.Certificate, error) {
if !certExists && keyExists {
return tls.Certificate{}, fmt.Errorf("Client TLS cert is empty, but client TLS private key was set.")
}
return tls.LoadX509KeyPair(cert, key)
return tls.X509KeyPair([]byte(cert), []byte(key))
}
func GetHttpClient() (*http.Client, error) {