From 30a7bceda300052d60b81a16292da7b6d428328e Mon Sep 17 00:00:00 2001 From: Alejandro Mery Date: Thu, 24 Aug 2023 19:49:59 +0000 Subject: [PATCH] wireguard: make KeyPairs solid Signed-off-by: Alejandro Mery --- pkg/wireguard/keys.go | 14 ++++++-------- pkg/zones/machine_rings.go | 33 ++++++++++++--------------------- pkg/zones/rings.go | 30 ++++++++++++------------------ 3 files changed, 30 insertions(+), 47 deletions(-) diff --git a/pkg/wireguard/keys.go b/pkg/wireguard/keys.go index 6c29284..5e01599 100644 --- a/pkg/wireguard/keys.go +++ b/pkg/wireguard/keys.go @@ -169,15 +169,13 @@ func (kp *KeyPair) Validate() error { } // NewKeyPair creates a new KeyPair for Wireguard -func NewKeyPair() (*KeyPair, error) { - key, err := NewPrivateKey() - if err != nil { - return nil, err - } +func NewKeyPair() (KeyPair, error) { + var out KeyPair - out := &KeyPair{ - PrivateKey: key, - PublicKey: key.Public(), + key, err := NewPrivateKey() + if err == nil { + out.PrivateKey = key + out.PublicKey = key.Public() } return out, nil } diff --git a/pkg/zones/machine_rings.go b/pkg/zones/machine_rings.go index 8085de7..350091c 100644 --- a/pkg/zones/machine_rings.go +++ b/pkg/zones/machine_rings.go @@ -11,25 +11,24 @@ import ( ) // GetWireguardKeys reads a wgN.key/wgN.pub files -func (m *Machine) GetWireguardKeys(ring int) (*wireguard.KeyPair, error) { +func (m *Machine) GetWireguardKeys(ring int) (wireguard.KeyPair, error) { var ( data []byte err error - key wireguard.PrivateKey - pub wireguard.PublicKey + out wireguard.KeyPair ) data, err = m.ReadFile("wg%v.key", ring) if err != nil { // failed to read - return nil, err + return out, err } - key, err = wireguard.PrivateKeyFromBase64(string(data)) + out.PrivateKey, err = wireguard.PrivateKeyFromBase64(string(data)) if err != nil { // bad key err = core.Wrapf(err, "wg%v.key", ring) - return nil, err + return out, err } data, err = m.ReadFile("wg%v.pub", ring) @@ -38,27 +37,19 @@ func (m *Machine) GetWireguardKeys(ring int) (*wireguard.KeyPair, error) { // no wgN.pub is fine case err != nil: // failed to read - return nil, err + return out, err default: // good read - pub, err = wireguard.PublicKeyFromBase64(string(data)) + out.PublicKey, err = wireguard.PublicKeyFromBase64(string(data)) if err != nil { // bad key err = core.Wrapf(err, "wg%v.pub", ring) - return nil, err + return out, err } } - kp := &wireguard.KeyPair{ - PrivateKey: key, - PublicKey: pub, - } - - if err = kp.Validate(); err != nil { - return nil, err - } - - return kp, nil + err = out.Validate() + return out, err } func (m *Machine) tryReadWireguardKeys(ring int) error { @@ -157,7 +148,7 @@ func (m *Machine) applyWireguardInterfaceConfig(ring int, data wireguard.Interfa ri := &RingInfo{ Ring: ring, Enabled: true, - Keys: &wireguard.KeyPair{ + Keys: wireguard.KeyPair{ PrivateKey: data.PrivateKey, }, } @@ -177,7 +168,7 @@ func (m *Machine) applyWireguardPeerConfig(ring int, pc wireguard.PeerConfig) er ri := &RingInfo{ Ring: ring, Enabled: true, - Keys: &wireguard.KeyPair{ + Keys: wireguard.KeyPair{ PublicKey: pc.PublicKey, }, } diff --git a/pkg/zones/rings.go b/pkg/zones/rings.go index ac25bd1..40108f1 100644 --- a/pkg/zones/rings.go +++ b/pkg/zones/rings.go @@ -23,14 +23,16 @@ const ( // RingInfo contains represents the Wireguard endpoint details // for a Machine on a particular ring type RingInfo struct { - Ring int `toml:"ring"` - Enabled bool `toml:"enabled,omitempty"` - Keys *wireguard.KeyPair `toml:"keys,omitempty"` + Ring int `toml:"ring"` + Enabled bool `toml:"enabled,omitempty"` + Keys wireguard.KeyPair `toml:"keys,omitempty"` } // Merge attempts to combine two RingInfo structs func (ri *RingInfo) Merge(alter *RingInfo) error { switch { + case alter == nil: + return nil case ri.Ring != alter.Ring: // different ring return fmt.Errorf("invalid %s: %v ≠ %v", "ring", ri.Ring, alter.Ring) @@ -51,27 +53,19 @@ func (ri *RingInfo) unsafeMerge(alter *RingInfo) error { ri.Enabled = true } - switch { - case ri.Keys == nil: - // assign keypair - ri.Keys = alter.Keys - case alter.Keys != nil: - // fill the gaps on our keypair - if ri.Keys.PrivateKey.IsZero() { - ri.Keys.PrivateKey = alter.Keys.PrivateKey - } - if ri.Keys.PublicKey.IsZero() { - ri.Keys.PublicKey = alter.Keys.PublicKey - } + // fill the gaps on our keypair + if ri.Keys.PrivateKey.IsZero() { + ri.Keys.PrivateKey = alter.Keys.PrivateKey + } + if ri.Keys.PublicKey.IsZero() { + ri.Keys.PublicKey = alter.Keys.PublicKey } return nil } -func canMergeKeyPairs(p1, p2 *wireguard.KeyPair) bool { +func canMergeKeyPairs(p1, p2 wireguard.KeyPair) bool { switch { - case p1 == nil || p2 == nil: - return true case !p1.PrivateKey.IsZero() && !p2.PrivateKey.IsZero() && !p1.PrivateKey.Equal(p2.PrivateKey): return false case !p1.PublicKey.IsZero() && !p2.PublicKey.IsZero() && !p1.PublicKey.Equal(p2.PublicKey):