Browse Source

wireguard: make KeyPairs solid

Signed-off-by: Alejandro Mery <amery@jpi.io>
pull/1/head
Alejandro Mery 10 months ago
parent
commit
30a7bceda3
  1. 14
      pkg/wireguard/keys.go
  2. 33
      pkg/zones/machine_rings.go
  3. 30
      pkg/zones/rings.go

14
pkg/wireguard/keys.go

@ -169,15 +169,13 @@ func (kp *KeyPair) Validate() error {
} }
// NewKeyPair creates a new KeyPair for Wireguard // NewKeyPair creates a new KeyPair for Wireguard
func NewKeyPair() (*KeyPair, error) { func NewKeyPair() (KeyPair, error) {
key, err := NewPrivateKey() var out KeyPair
if err != nil {
return nil, err
}
out := &KeyPair{ key, err := NewPrivateKey()
PrivateKey: key, if err == nil {
PublicKey: key.Public(), out.PrivateKey = key
out.PublicKey = key.Public()
} }
return out, nil return out, nil
} }

33
pkg/zones/machine_rings.go

@ -11,25 +11,24 @@ import (
) )
// GetWireguardKeys reads a wgN.key/wgN.pub files // GetWireguardKeys reads a wgN.key/wgN.pub files
func (m *Machine) GetWireguardKeys(ring int) (*wireguard.KeyPair, error) { func (m *Machine) GetWireguardKeys(ring int) (wireguard.KeyPair, error) {
var ( var (
data []byte data []byte
err error err error
key wireguard.PrivateKey out wireguard.KeyPair
pub wireguard.PublicKey
) )
data, err = m.ReadFile("wg%v.key", ring) data, err = m.ReadFile("wg%v.key", ring)
if err != nil { if err != nil {
// failed to read // failed to read
return nil, err return out, err
} }
key, err = wireguard.PrivateKeyFromBase64(string(data)) out.PrivateKey, err = wireguard.PrivateKeyFromBase64(string(data))
if err != nil { if err != nil {
// bad key // bad key
err = core.Wrapf(err, "wg%v.key", ring) err = core.Wrapf(err, "wg%v.key", ring)
return nil, err return out, err
} }
data, err = m.ReadFile("wg%v.pub", ring) data, err = m.ReadFile("wg%v.pub", ring)
@ -38,27 +37,19 @@ func (m *Machine) GetWireguardKeys(ring int) (*wireguard.KeyPair, error) {
// no wgN.pub is fine // no wgN.pub is fine
case err != nil: case err != nil:
// failed to read // failed to read
return nil, err return out, err
default: default:
// good read // good read
pub, err = wireguard.PublicKeyFromBase64(string(data)) out.PublicKey, err = wireguard.PublicKeyFromBase64(string(data))
if err != nil { if err != nil {
// bad key // bad key
err = core.Wrapf(err, "wg%v.pub", ring) err = core.Wrapf(err, "wg%v.pub", ring)
return nil, err return out, err
} }
} }
kp := &wireguard.KeyPair{ err = out.Validate()
PrivateKey: key, return out, err
PublicKey: pub,
}
if err = kp.Validate(); err != nil {
return nil, err
}
return kp, nil
} }
func (m *Machine) tryReadWireguardKeys(ring int) error { func (m *Machine) tryReadWireguardKeys(ring int) error {
@ -157,7 +148,7 @@ func (m *Machine) applyWireguardInterfaceConfig(ring int, data wireguard.Interfa
ri := &RingInfo{ ri := &RingInfo{
Ring: ring, Ring: ring,
Enabled: true, Enabled: true,
Keys: &wireguard.KeyPair{ Keys: wireguard.KeyPair{
PrivateKey: data.PrivateKey, PrivateKey: data.PrivateKey,
}, },
} }
@ -177,7 +168,7 @@ func (m *Machine) applyWireguardPeerConfig(ring int, pc wireguard.PeerConfig) er
ri := &RingInfo{ ri := &RingInfo{
Ring: ring, Ring: ring,
Enabled: true, Enabled: true,
Keys: &wireguard.KeyPair{ Keys: wireguard.KeyPair{
PublicKey: pc.PublicKey, PublicKey: pc.PublicKey,
}, },
} }

30
pkg/zones/rings.go

@ -23,14 +23,16 @@ const (
// RingInfo contains represents the Wireguard endpoint details // RingInfo contains represents the Wireguard endpoint details
// for a Machine on a particular ring // for a Machine on a particular ring
type RingInfo struct { type RingInfo struct {
Ring int `toml:"ring"` Ring int `toml:"ring"`
Enabled bool `toml:"enabled,omitempty"` Enabled bool `toml:"enabled,omitempty"`
Keys *wireguard.KeyPair `toml:"keys,omitempty"` Keys wireguard.KeyPair `toml:"keys,omitempty"`
} }
// Merge attempts to combine two RingInfo structs // Merge attempts to combine two RingInfo structs
func (ri *RingInfo) Merge(alter *RingInfo) error { func (ri *RingInfo) Merge(alter *RingInfo) error {
switch { switch {
case alter == nil:
return nil
case ri.Ring != alter.Ring: case ri.Ring != alter.Ring:
// different ring // different ring
return fmt.Errorf("invalid %s: %v ≠ %v", "ring", ri.Ring, alter.Ring) return fmt.Errorf("invalid %s: %v ≠ %v", "ring", ri.Ring, alter.Ring)
@ -51,27 +53,19 @@ func (ri *RingInfo) unsafeMerge(alter *RingInfo) error {
ri.Enabled = true ri.Enabled = true
} }
switch { // fill the gaps on our keypair
case ri.Keys == nil: if ri.Keys.PrivateKey.IsZero() {
// assign keypair ri.Keys.PrivateKey = alter.Keys.PrivateKey
ri.Keys = alter.Keys }
case alter.Keys != nil: if ri.Keys.PublicKey.IsZero() {
// fill the gaps on our keypair ri.Keys.PublicKey = alter.Keys.PublicKey
if ri.Keys.PrivateKey.IsZero() {
ri.Keys.PrivateKey = alter.Keys.PrivateKey
}
if ri.Keys.PublicKey.IsZero() {
ri.Keys.PublicKey = alter.Keys.PublicKey
}
} }
return nil return nil
} }
func canMergeKeyPairs(p1, p2 *wireguard.KeyPair) bool { func canMergeKeyPairs(p1, p2 wireguard.KeyPair) bool {
switch { switch {
case p1 == nil || p2 == nil:
return true
case !p1.PrivateKey.IsZero() && !p2.PrivateKey.IsZero() && !p1.PrivateKey.Equal(p2.PrivateKey): case !p1.PrivateKey.IsZero() && !p2.PrivateKey.IsZero() && !p1.PrivateKey.Equal(p2.PrivateKey):
return false return false
case !p1.PublicKey.IsZero() && !p2.PublicKey.IsZero() && !p1.PublicKey.Equal(p2.PublicKey): case !p1.PublicKey.IsZero() && !p2.PublicKey.IsZero() && !p1.PublicKey.Equal(p2.PublicKey):

Loading…
Cancel
Save