diff --git a/pkg/wireguard/keys.go b/pkg/wireguard/keys.go index c18cf3c..6c29284 100644 --- a/pkg/wireguard/keys.go +++ b/pkg/wireguard/keys.go @@ -27,60 +27,71 @@ var ( type ( // PrivateKey is a binary Wireguard Private Key - PrivateKey []byte + PrivateKey [PrivateKeySize]byte // PublicKey is a binary Wireguard Public Key - PublicKey []byte + PublicKey [PublicKeySize]byte ) func (key PrivateKey) String() string { - return encodeKey(key) + switch { + case key.IsZero(): + return "" + default: + return base64.StdEncoding.EncodeToString(key[:]) + } } func (pub PublicKey) String() string { - return encodeKey(pub) + switch { + case pub.IsZero(): + return "" + default: + return base64.StdEncoding.EncodeToString(pub[:]) + } } // IsZero tells if the key hasn't been set func (key PrivateKey) IsZero() bool { - return len(key) == 0 + var zero PrivateKey + return key.Equal(zero) } // IsZero tells if the key hasn't been set func (pub PublicKey) IsZero() bool { - return len(pub) == 0 + var zero PublicKey + return pub.Equal(zero) } // Equal checks if two private keys are identical func (key PrivateKey) Equal(alter PrivateKey) bool { - return bytes.Equal(key, alter) + return bytes.Equal(key[:], alter[:]) } // Equal checks if two public keys are identical func (pub PublicKey) Equal(alter PublicKey) bool { - return bytes.Equal(pub, alter) + return bytes.Equal(pub[:], alter[:]) } // PrivateKeyFromBase64 decodes a base64-based string into // a [PrivateKey] func PrivateKeyFromBase64(data string) (PrivateKey, error) { b, err := decodeKey(data, PrivateKeySize) - return b, err + if err != nil { + var zero PrivateKey + return zero, err + } + return *(*[PrivateKeySize]byte)(b), nil } // PublicKeyFromBase64 decodes a base64-based string into // a [PublicKey] func PublicKeyFromBase64(data string) (PublicKey, error) { b, err := decodeKey(data, PublicKeySize) - return b, err -} - -func encodeKey(b []byte) string { - switch { - case len(b) == 0: - return "" - default: - return base64.StdEncoding.EncodeToString(b) + if err != nil { + var zero PublicKey + return zero, err } + return *(*[PublicKeySize]byte)(b), nil } func decodeKey(data string, size int) ([]byte, error) { @@ -102,27 +113,27 @@ func NewPrivateKey() (PrivateKey, error) { _, err := rand.Read(s[:]) if err != nil { - return []byte{}, err + var zero PrivateKey + return zero, err } // apply same clamping as wireguard-go/device/noise-helpers.go s[0] &= 0xf8 s[31] = (s[31] & 0x7f) | 0x40 - return s[:], nil + return s, nil } // Public generates the corresponding PublicKey func (key PrivateKey) Public() PublicKey { - if len(key) != PrivateKeySize { - return []byte{} - } + var pub PublicKey + if !key.IsZero() { + in := (*[PrivateKeySize]byte)(&key) + out := (*[PublicKeySize]byte)(&pub) - out := [PublicKeySize]byte{} - in := (*[PrivateKeySize]byte)(key) - - curve25519.ScalarBaseMult(&out, in) - return out[:] + curve25519.ScalarBaseMult(out, in) + } + return pub } // KeyPair holds a Key pair