// Package wireguard provides helpers to generate, encode, decode and
// validate Wireguard keys.
package wireguard

import (
	"bytes"
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"errors"

	"golang.org/x/crypto/curve25519"
)

const (
	// PrivateKeySize is the length in bytes of a Wireguard Private Key.
	PrivateKeySize = 32
	// PublicKeySize is the length in bytes of a Wireguard Public Key.
	PublicKeySize = 32
)

var (
	// ErrInvalidKeySize indicates the key size is wrong.
	ErrInvalidKeySize = errors.New("invalid key size")
	// ErrInvalidPrivateKey indicates the private key is invalid.
	ErrInvalidPrivateKey = errors.New("invalid private key")
	// ErrInvalidPublicKey indicates the public key is invalid.
	ErrInvalidPublicKey = errors.New("invalid public key")
)

type (
	// PrivateKey is a binary Wireguard Private Key.
	PrivateKey [PrivateKeySize]byte
	// PublicKey is a binary Wireguard Public Key.
	PublicKey [PublicKeySize]byte
)

// String returns the standard base64 encoding of the private key,
// or an empty string if the key hasn't been set.
func (key PrivateKey) String() string {
	if key.IsZero() {
		return ""
	}
	return base64.StdEncoding.EncodeToString(key[:])
}

// String returns the standard base64 encoding of the public key,
// or an empty string if the key hasn't been set.
func (pub PublicKey) String() string {
	if pub.IsZero() {
		return ""
	}
	return base64.StdEncoding.EncodeToString(pub[:])
}

// IsZero tells if the key hasn't been set.
func (key PrivateKey) IsZero() bool {
	var zero PrivateKey
	return key.Equal(zero)
}

// IsZero tells if the key hasn't been set.
func (pub PublicKey) IsZero() bool {
	var zero PublicKey
	return pub.Equal(zero)
}

// Equal checks if two private keys are identical.
// The comparison runs in constant time because the operands are secrets.
func (key PrivateKey) Equal(alter PrivateKey) bool {
	return subtle.ConstantTimeCompare(key[:], alter[:]) == 1
}

// Equal checks if two public keys are identical.
func (pub PublicKey) Equal(alter PublicKey) bool {
	return bytes.Equal(pub[:], alter[:])
}

// PrivateKeyFromBase64 decodes a base64-based string into
// a [PrivateKey]. It returns the base64 decoding error, or
// [ErrInvalidKeySize] if the decoded data isn't [PrivateKeySize]
// bytes long.
func PrivateKeyFromBase64(data string) (PrivateKey, error) {
	var key PrivateKey

	b, err := decodeKey(data, PrivateKeySize)
	if err != nil {
		return key, err
	}

	copy(key[:], b)
	return key, nil
}

// PublicKeyFromBase64 decodes a base64-based string into
// a [PublicKey]. It returns the base64 decoding error, or
// [ErrInvalidKeySize] if the decoded data isn't [PublicKeySize]
// bytes long.
func PublicKeyFromBase64(data string) (PublicKey, error) {
	var pub PublicKey

	b, err := decodeKey(data, PublicKeySize)
	if err != nil {
		return pub, err
	}

	copy(pub[:], b)
	return pub, nil
}

// decodeKey decodes a standard base64 string and verifies the result
// is exactly size bytes long. On failure the returned slice is nil.
func decodeKey(data string, size int) ([]byte, error) {
	b, err := base64.StdEncoding.DecodeString(data)
	switch {
	case err != nil:
		return nil, err
	case len(b) != size:
		return nil, ErrInvalidKeySize
	default:
		return b, nil
	}
}

// NewPrivateKey creates a new PrivateKey from crypto/rand data.
func NewPrivateKey() (PrivateKey, error) {
	var key PrivateKey

	if _, err := rand.Read(key[:]); err != nil {
		// never hand out a partially filled key
		return PrivateKey{}, err
	}

	// apply same clamping as wireguard-go/device/noise-helpers.go
	key[0] &= 0xf8
	key[31] = (key[31] & 0x7f) | 0x40
	return key, nil
}

// Public generates the corresponding PublicKey.
// A zero PrivateKey yields a zero PublicKey.
func (key PrivateKey) Public() PublicKey {
	var pub PublicKey

	if !key.IsZero() {
		in := (*[PrivateKeySize]byte)(&key)
		out := (*[PublicKeySize]byte)(&pub)
		curve25519.ScalarBaseMult(out, in)
	}

	return pub
}

// KeyPair holds a Key pair.
type KeyPair struct {
	PrivateKey PrivateKey
	PublicKey  PublicKey
}

// Validate checks the PublicKey matches the PrivateKey,
// and sets the PublicKey if missing.
//
// It returns [ErrInvalidPrivateKey] if the PrivateKey hasn't been set,
// and [ErrInvalidPublicKey] if the PublicKey is set but doesn't
// correspond to the PrivateKey.
func (kp *KeyPair) Validate() error {
	// Keys are fixed-size arrays, so len() can't tell whether they were
	// set; "missing" is represented by the zero value instead.
	switch {
	case kp.PrivateKey.IsZero():
		// no private key
		return ErrInvalidPrivateKey
	case kp.PublicKey.IsZero():
		// no public key, set it
		kp.PublicKey = kp.PrivateKey.Public()
		return nil
	case !kp.PrivateKey.Public().Equal(kp.PublicKey):
		// wrong public key
		return ErrInvalidPublicKey
	default:
		// correct public key
		return nil
	}
}

// NewKeyPair creates a new KeyPair for Wireguard, with a freshly
// generated PrivateKey and its corresponding PublicKey.
func NewKeyPair() (*KeyPair, error) {
	key, err := NewPrivateKey()
	if err != nil {
		return nil, err
	}

	out := &KeyPair{
		PrivateKey: key,
		PublicKey:  key.Public(),
	}
	return out, nil
}