package wireguard import ( "bytes" "crypto/rand" "encoding/base64" "errors" "fmt" "golang.org/x/crypto/curve25519" ) const ( // PrivateKeySize is the length in bytes of a Wireguard Private Key PrivateKeySize = 32 // PublicKeySize is the length in bytes of a Wireguard Public Key PublicKeySize = 32 ) var ( // ErrInvalidKeySize indicates the key size is wrong ErrInvalidKeySize = errors.New("invalid key size") // ErrInvalidPrivateKey indicates the private key is invalid ErrInvalidPrivateKey = errors.New("invalid private key") // ErrInvalidPublicKey indicates the public key is invalid ErrInvalidPublicKey = errors.New("invalid public key") ) type ( // PrivateKey is a binary Wireguard Private Key PrivateKey [PrivateKeySize]byte // PublicKey is a binary Wireguard Public Key PublicKey [PublicKeySize]byte ) func (key PrivateKey) String() string { switch { case key.IsZero(): return "" default: return base64.StdEncoding.EncodeToString(key[:]) } } func (pub PublicKey) String() string { switch { case pub.IsZero(): return "" default: return base64.StdEncoding.EncodeToString(pub[:]) } } // UnmarshalText loads the value from base64 func (key *PrivateKey) UnmarshalText(b []byte) error { v, err := PrivateKeyFromBase64(string(b)) if err != nil { return err } *key = v return nil } // UnmarshalText loads the value from base64 func (pub *PublicKey) UnmarshalText(b []byte) error { v, err := PublicKeyFromBase64(string(b)) if err != nil { return err } *pub = v return nil } // MarshalJSON encodes the key for JSON, omitting empty. func (key PrivateKey) MarshalJSON() ([]byte, error) { return encodeKeyJSON(key.String()) } // MarshalJSON encodes the key for JSON, omitting empty. func (pub PublicKey) MarshalJSON() ([]byte, error) { return encodeKeyJSON(pub.String()) } func encodeKeyJSON(s string) ([]byte, error) { var out []byte if s != "" { out = []byte(fmt.Sprintf("%q", s)) } return out, nil } // MarshalYAML encodes the key for YAML, omitting empty. func (key PrivateKey) MarshalYAML() (any, error) { return encodeKeyYAML(key.String()) } // MarshalYAML encodes the key for YAML, omitting empty. func (pub PublicKey) MarshalYAML() (any, error) { return encodeKeyYAML(pub.String()) } func encodeKeyYAML(s string) (any, error) { if s == "" { return nil, nil } return s, nil } // IsZero tells if the key hasn't been set func (key PrivateKey) IsZero() bool { var zero PrivateKey return key.Equal(zero) } // IsZero tells if the key hasn't been set func (pub PublicKey) IsZero() bool { var zero PublicKey return pub.Equal(zero) } // Equal checks if two private keys are identical func (key PrivateKey) Equal(alter PrivateKey) bool { return bytes.Equal(key[:], alter[:]) } // Equal checks if two public keys are identical func (pub PublicKey) Equal(alter PublicKey) bool { return bytes.Equal(pub[:], alter[:]) } // PrivateKeyFromBase64 decodes a base64-based string into // a [PrivateKey] func PrivateKeyFromBase64(data string) (PrivateKey, error) { b, err := decodeKey(data, PrivateKeySize) if err != nil { var zero PrivateKey return zero, err } return *(*[PrivateKeySize]byte)(b), nil } // PublicKeyFromBase64 decodes a base64-based string into // a [PublicKey] func PublicKeyFromBase64(data string) (PublicKey, error) { b, err := decodeKey(data, PublicKeySize) if err != nil { var zero PublicKey return zero, err } return *(*[PublicKeySize]byte)(b), nil } func decodeKey(data string, size int) ([]byte, error) { b, err := base64.StdEncoding.DecodeString(data) switch { case err != nil: return []byte{}, err case len(b) != size: err = ErrInvalidKeySize return []byte{}, err default: return b, nil } } // NewPrivateKey creates a new PrivateKey func NewPrivateKey() (PrivateKey, error) { var s [PrivateKeySize]byte _, err := rand.Read(s[:]) if err != nil { var zero PrivateKey return zero, err } // apply same clamping as wireguard-go/device/noise-helpers.go s[0] &= 0xf8 s[31] = (s[31] & 0x7f) | 0x40 return s, nil } // Public generates the corresponding PublicKey func (key PrivateKey) Public() PublicKey { var pub PublicKey if !key.IsZero() { in := (*[PrivateKeySize]byte)(&key) out := (*[PublicKeySize]byte)(&pub) curve25519.ScalarBaseMult(out, in) } return pub } // KeyPair holds a Key pair type KeyPair struct { PrivateKey PrivateKey PublicKey PublicKey } // Validate checks the PublicKey matches the PrivateKey, // and sets the PublicKey if missing func (kp *KeyPair) Validate() error { switch { case kp.PrivateKey.IsZero(): // no private key return ErrInvalidPrivateKey case kp.PublicKey.IsZero(): // no public key, set it kp.PublicKey = kp.PrivateKey.Public() return nil case !kp.PrivateKey.Public().Equal(kp.PublicKey): // wrong public key return ErrInvalidPublicKey default: // correct public key return nil } } // NewKeyPair creates a new KeyPair for Wireguard func NewKeyPair() (KeyPair, error) { var out KeyPair key, err := NewPrivateKey() if err == nil { out.PrivateKey = key out.PublicKey = key.Public() } return out, nil }