diff --git a/pkg/wireguard/keys.go b/pkg/wireguard/keys.go index 100eb02..c18cf3c 100644 --- a/pkg/wireguard/keys.go +++ b/pkg/wireguard/keys.go @@ -19,6 +19,10 @@ const ( var ( // ErrInvalidKeySize indicates the key size is wrong ErrInvalidKeySize = errors.New("invalid key size") + // ErrInvalidPrivateKey indicates the private key is invalid + ErrInvalidPrivateKey = errors.New("invalid private key") + // ErrInvalidPublicKey indicates the public key is invalid + ErrInvalidPublicKey = errors.New("invalid public key") ) type ( @@ -127,6 +131,32 @@ type KeyPair struct { PublicKey PublicKey } +// Validate checks the PublicKey matches the PrivateKey, +// and sets the PublicKey if missing +func (kp *KeyPair) Validate() error { + keyLen := len(kp.PrivateKey) + pubLen := len(kp.PublicKey) + + switch { + case keyLen != PrivateKeySize: + // bad private key + return ErrInvalidPrivateKey + case pubLen == 0: + // no public key, set it + kp.PublicKey = kp.PrivateKey.Public() + return nil + case pubLen != PublicKeySize: + // bad public key + return ErrInvalidPublicKey + case !kp.PrivateKey.Public().Equal(kp.PublicKey): + // wrong public key + return ErrInvalidPublicKey + default: + // correct public key + return nil + } +} + // NewKeyPair creates a new KeyPair for Wireguard func NewKeyPair() (*KeyPair, error) { key, err := NewPrivateKey()