package wireguard

import (
	"bytes"
	"crypto/rand"
	"encoding/base64"
	"errors"
	"fmt"

	"golang.org/x/crypto/curve25519"
)

const (
	// PublicKeySize is the number of bytes in a raw (decoded) Wireguard
	// public key.
	PublicKeySize = 32
	// PrivateKeySize is the number of bytes in a raw (decoded) Wireguard
	// private key.
	PrivateKeySize = 32
)

var (
	// ErrInvalidKeySize indicates the key size is wrong
	ErrInvalidKeySize = errors.New("invalid key size")
	// ErrInvalidPrivateKey indicates the private key is invalid
	ErrInvalidPrivateKey = errors.New("invalid private key")
	// ErrInvalidPublicKey indicates the public key is invalid
	ErrInvalidPublicKey = errors.New("invalid public key")
)

type (
	// PrivateKey is a binary Wireguard Private Key
	PrivateKey [PrivateKeySize]byte
	// PublicKey is a binary Wireguard Public Key
	PublicKey [PublicKeySize]byte
)

// String returns the key in standard, padded base64, or "" for the zero key.
//
// Note that this exposes the secret key material; avoid handing PrivateKey
// values to loggers or %v formatting.
func (key PrivateKey) String() string {
	if key.IsZero() {
		return ""
	}
	return base64.StdEncoding.EncodeToString(key[:])
}

// String returns the key in standard, padded base64, or "" for the zero key.
func (pub PublicKey) String() string {
	if pub.IsZero() {
		return ""
	}
	return base64.StdEncoding.EncodeToString(pub[:])
}

// MarshalJSON implements json.Marshaler. A non-zero key is emitted as its
// quoted base64 form; the zero key's empty String() is handed to
// encodeKeyJSON, which decides how an unset key is represented.
func (key PrivateKey) MarshalJSON() ([]byte, error) {
	encoded := key.String()
	return encodeKeyJSON(encoded)
}

// MarshalJSON implements json.Marshaler. A non-zero key is emitted as its
// quoted base64 form; the zero key's empty String() is handed to
// encodeKeyJSON, which decides how an unset key is represented.
func (pub PublicKey) MarshalJSON() ([]byte, error) {
	encoded := pub.String()
	return encodeKeyJSON(encoded)
}

// encodeKeyJSON renders the base64 text of a key as a JSON value.
//
// A non-empty s becomes a quoted JSON string. Go's %q quoting matches JSON
// here because s only ever holds base64 characters. An empty s (the zero
// key) becomes the JSON literal null. encoding/json rejects empty
// MarshalJSON output as invalid JSON, and a Marshaler cannot omit its own
// field. To leave an unset key out of the output entirely, tag the field
// `json:",omitzero"` (Go 1.24+), which consults the key's IsZero method.
func encodeKeyJSON(s string) ([]byte, error) {
	if s == "" {
		return []byte("null"), nil
	}
	return []byte(fmt.Sprintf("%q", s)), nil
}

// MarshalYAML returns the value a YAML encoder should emit for the key:
// its base64 string, or nil when the key is zero.
func (key PrivateKey) MarshalYAML() (any, error) {
	encoded := key.String()
	return encodeKeyYAML(encoded)
}

// MarshalYAML returns the value a YAML encoder should emit for the key:
// its base64 string, or nil when the key is zero.
func (pub PublicKey) MarshalYAML() (any, error) {
	encoded := pub.String()
	return encodeKeyYAML(encoded)
}

// encodeKeyYAML maps the base64 text of a key to a YAML value. The empty
// string (the zero key) becomes an untyped nil; anything else is returned
// unchanged as a string. It never fails.
func encodeKeyYAML(s string) (any, error) {
	var value any
	if s != "" {
		value = s
	}
	return value, nil
}

// IsZero reports whether the key is all zero bytes, i.e. was never set.
// It delegates to Equal so the check shares Equal's comparison semantics.
func (key PrivateKey) IsZero() bool {
	return key.Equal(PrivateKey{})
}

// IsZero reports whether the key is all zero bytes, i.e. was never set.
func (pub PublicKey) IsZero() bool {
	return pub.Equal(PublicKey{})
}

// Equal reports whether two private keys are identical.
//
// Private keys are secret, so the comparison is constant-time, in the
// spirit of crypto/subtle.ConstantTimeCompare. Every byte pair is XORed and
// OR-accumulated without data-dependent branches. An early-exit compare such
// as bytes.Equal would reveal through timing how many leading bytes match.
func (key PrivateKey) Equal(alter PrivateKey) bool {
	var diff byte
	for i := range key {
		diff |= key[i] ^ alter[i]
	}
	return diff == 0
}

// Equal reports whether pub and alter hold exactly the same bytes.
func (pub PublicKey) Equal(alter PublicKey) bool {
	lhs, rhs := pub[:], alter[:]
	return bytes.Equal(lhs, rhs)
}

// PrivateKeyFromBase64 parses a standard-base64 string into a [PrivateKey].
// It returns the zero key together with the error when data is not valid
// base64 or does not decode to exactly PrivateKeySize bytes
// (ErrInvalidKeySize).
func PrivateKeyFromBase64(data string) (PrivateKey, error) {
	var key PrivateKey
	raw, err := decodeKey(data, PrivateKeySize)
	if err == nil {
		// decodeKey guarantees len(raw) == PrivateKeySize on success.
		copy(key[:], raw)
	}
	return key, err
}

// PublicKeyFromBase64 parses a standard-base64 string into a [PublicKey].
// It returns the zero key together with the error when data is not valid
// base64 or does not decode to exactly PublicKeySize bytes
// (ErrInvalidKeySize).
func PublicKeyFromBase64(data string) (PublicKey, error) {
	var pub PublicKey
	raw, err := decodeKey(data, PublicKeySize)
	if err == nil {
		// decodeKey guarantees len(raw) == PublicKeySize on success.
		copy(pub[:], raw)
	}
	return pub, err
}

// decodeKey decodes standard, padded base64 text and checks that the result
// is exactly size bytes long. On failure it returns an empty, non-nil slice
// along with either the base64 decoding error or ErrInvalidKeySize.
func decodeKey(data string, size int) ([]byte, error) {
	raw, err := base64.StdEncoding.DecodeString(data)
	if err != nil {
		return []byte{}, err
	}
	if len(raw) != size {
		return []byte{}, ErrInvalidKeySize
	}
	return raw, nil
}

// NewPrivateKey generates a random PrivateKey from crypto/rand and clamps it
// as a Curve25519 scalar. If the random source fails, it returns the zero
// key and that error.
func NewPrivateKey() (PrivateKey, error) {
	var key PrivateKey
	if _, err := rand.Read(key[:]); err != nil {
		return PrivateKey{}, err
	}

	// Clamp exactly like wireguard-go/device/noise-helpers.go: clear the
	// three low bits, clear the top bit, and set the second-highest bit.
	key[0] &= 0xf8
	key[31] &= 0x7f
	key[31] |= 0x40

	return key, nil
}

// Public derives the PublicKey matching key by multiplying the Curve25519
// base point by key. The zero PrivateKey yields the zero PublicKey and is
// never run through the curve.
func (key PrivateKey) Public() PublicKey {
	var pub PublicKey
	if key.IsZero() {
		return pub
	}
	curve25519.ScalarBaseMult((*[PublicKeySize]byte)(&pub), (*[PrivateKeySize]byte)(&key))
	return pub
}

// KeyPair holds a Wireguard private key together with its public key.
// The zero value has neither key set. Use NewKeyPair to generate one, or
// call Validate to check the pair and derive a missing PublicKey.
type KeyPair struct {
	PrivateKey PrivateKey // secret half; Validate rejects the zero key
	PublicKey  PublicKey  // expected to equal PrivateKey.Public()
}

// Validate checks that kp is a consistent key pair.
//
// It returns ErrInvalidPrivateKey when PrivateKey is zero. If PublicKey is
// zero, it is filled in with PrivateKey.Public() and nil is returned.
// Otherwise ErrInvalidPublicKey is returned when PublicKey differs from the
// key derived from PrivateKey.
func (kp *KeyPair) Validate() error {
	if kp.PrivateKey.IsZero() {
		return ErrInvalidPrivateKey
	}

	derived := kp.PrivateKey.Public()
	if kp.PublicKey.IsZero() {
		kp.PublicKey = derived
		return nil
	}
	if !derived.Equal(kp.PublicKey) {
		return ErrInvalidPublicKey
	}
	return nil
}

// NewKeyPair creates a new random KeyPair for Wireguard, with PublicKey
// derived from the freshly generated PrivateKey.
//
// If generating the private key fails, it returns an empty KeyPair and the
// error from NewPrivateKey. Previously that error was dropped and a zero
// pair was reported as success.
func NewKeyPair() (KeyPair, error) {
	key, err := NewPrivateKey()
	if err != nil {
		return KeyPair{}, err
	}

	return KeyPair{
		PrivateKey: key,
		PublicKey:  key.Public(),
	}, nil
}