package wireguard

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"net/netip"
	"strconv"
	"strings"
	"text/template"

	"darvaza.org/core"
	"gopkg.in/gcfg.v1"
)

// configTemplate renders a [Config] as a wgN.conf file: one [Interface]
// section followed by one [Peer] section per entry of Config.Peer.
// Optional Name fields are emitted as "# Name:" comment lines only.
var configTemplate = template.Must(template.New("config").
	Funcs(template.FuncMap{
		// PrefixJoin renders a list of prefixes separated by sep,
		// e.g. {{ PrefixJoin .AllowedIPs ", " }}.
		"PrefixJoin": func(prefixes []netip.Prefix, sep string) string {
			var sb strings.Builder
			for i, p := range prefixes {
				if i > 0 {
					sb.WriteString(sep)
				}
				sb.WriteString(p.String())
			}
			return sb.String()
		},
	}).
	Parse(`[Interface]
{{if .Interface.Name}}# Name: {{.Interface.Name}}
{{end -}}
Address = {{.Interface.Address}}
PrivateKey = {{.Interface.PrivateKey}}
ListenPort = {{.Interface.ListenPort}}
{{- range .Peer }}

[Peer]
{{if .Name}}# Name: {{.Name}}
{{end -}}
PublicKey = {{.PublicKey}}
Endpoint = {{.Endpoint}}
AllowedIPs = {{ PrefixJoin .AllowedIPs ", "}}
{{- end }}
`))

// Config represents a wgN.conf file: one [Interface] section followed by
// zero or more [Peer] sections, in the order WriteTo emits them.
type Config struct {
	// Interface holds the [Interface] section.
	Interface InterfaceConfig
	// Peer holds one entry per [Peer] section.
	Peer      []PeerConfig
}

// GetAddress returns the address configured in the [Interface] section.
// It is a convenience shortcut for f.Interface.Address.
func (f *Config) GetAddress() netip.Addr {
	addr := f.Interface.Address
	return addr
}

// Peers reports the number of [Peer] sections held by the Config.
func (f *Config) Peers() int {
	count := len(f.Peer)
	return count
}

// WriteTo writes a Wireguard [Config] onto the provided [io.Writer].
//
// The whole file is rendered into memory first, so a template error
// leaves w untouched and reports 0 bytes written. On success it returns
// the byte count and any error from the final write to w.
func (f *Config) WriteTo(w io.Writer) (int64, error) {
	rendered := new(bytes.Buffer)

	err := configTemplate.Execute(rendered, f)
	if err != nil {
		return 0, err
	}

	return rendered.WriteTo(w)
}

// InterfaceConfig represents the [Interface] section
type InterfaceConfig struct {
	// Name is only written as a "# Name:" comment line; the parser in
	// this file does not read it back.
	Name       string
	// Address is a single netip.Addr (no prefix length).
	Address    netip.Addr
	PrivateKey PrivateKey
	// ListenPort is written verbatim, including when it is 0.
	ListenPort uint16
}

// PeerConfig represents a [Peer] section
type PeerConfig struct {
	// Name is only written as a "# Name:" comment line; the parser in
	// this file does not read it back.
	Name       string
	PublicKey  PublicKey
	Endpoint   EndpointAddress
	// AllowedIPs is written as a ", "-separated list of prefixes.
	AllowedIPs []netip.Prefix
}

// EndpointAddress is a host:port pair to reach the Peer.
// A zero Port means "no port": String then renders the bare Host.
type EndpointAddress struct {
	// Host is a hostname or IP literal. String wraps it in brackets when
	// it contains ':' and a port is present.
	Host string
	Port uint16
}

// Name returns the first label of the endpoint's Host, that is,
// everything before the first '.'. It returns the whole Host when it
// contains no dot.
func (ep EndpointAddress) Name() string {
	if dot := strings.IndexByte(ep.Host, '.'); dot >= 0 {
		return ep.Host[:dot]
	}
	return ep.Host
}

// String renders the endpoint as "host:port", or "[host]:port" when Host
// contains a ':' (IPv6 literal). It returns the bare Host when Port is 0,
// and "" when Host is empty.
func (ep EndpointAddress) String() string {
	if ep.Host == "" {
		return ""
	}
	if ep.Port == 0 {
		return ep.Host
	}

	layout := "%s:%v"
	if strings.ContainsRune(ep.Host, ':') {
		// bracket IPv6 literals so the port separator stays unambiguous
		layout = "[%s]:%v"
	}
	return fmt.Sprintf(layout, ep.Host, ep.Port)
}

// FromString sets the EndpointAddress from a given "[host]:port".
//
// An empty port, as returned by core.SplitHostPort, leaves Port at 0.
// A port that is not a decimal number in [0, 65535] is rejected instead of
// being silently coerced. On any error ep is left unmodified.
func (ep *EndpointAddress) FromString(s string) error {
	host, port, err := core.SplitHostPort(s)
	if err != nil {
		return err
	}

	var n uint64
	if port != "" {
		// bitSize 16 makes out-of-range values (e.g. "99999") an error
		// rather than a clamped 65535.
		n, err = strconv.ParseUint(port, 10, 16)
		if err != nil {
			return core.Wrap(err, "port")
		}
	}

	ep.Host = host
	ep.Port = uint16(n)
	return nil
}

// intermediateConfig is the raw shape decoded by gcfg.ReadInto in
// NewConfigFromReader, before validation and conversion by Export.
type intermediateConfig struct {
	Interface interfaceConfig
	// Peer stores every peer's fields as parallel slices indexed by peer
	// number (presumably gcfg accumulates repeated [Peer] sections into
	// these multi-valued variables — TODO confirm).
	Peer      peersConfig
}

// Export validates the raw gcfg data and converts it into a [Config].
//
// The [Interface] section is converted first, then each peer in order.
// The first failure aborts the export; peer errors are annotated with the
// peer's index. With no peers, the resulting Config.Peer is nil.
func (v *intermediateConfig) Export() (*Config, error) {
	iface, err := v.Interface.Export()
	if err != nil {
		return nil, err
	}

	count, ok := v.PeersCount()
	if !ok {
		// Endpoint, PublicKey and AllowedIPs must line up one-to-one.
		return nil, errors.New("inconsistent Peer data")
	}

	out := &Config{Interface: iface}
	for idx := 0; idx < count; idx++ {
		peer, perr := v.ExportPeer(idx)
		if perr != nil {
			return nil, core.Wrap(perr, "Peer[%v]:", idx)
		}
		out.Peer = append(out.Peer, peer)
	}

	return out, nil
}

// interfaceConfig is the raw [Interface] section as decoded by gcfg.
type interfaceConfig struct {
	Address    netip.Addr
	// PrivateKey is kept as text here; Export decodes it with
	// PrivateKeyFromBase64.
	PrivateKey string
	ListenPort uint16
}

// Export converts the raw [Interface] section into an [InterfaceConfig],
// decoding the base64 private key. On failure it returns the zero
// InterfaceConfig and an error annotated with "PrivateKey".
func (p interfaceConfig) Export() (InterfaceConfig, error) {
	key, err := PrivateKeyFromBase64(p.PrivateKey)
	if err != nil {
		return InterfaceConfig{}, core.Wrap(err, "PrivateKey")
	}

	return InterfaceConfig{
		Address:    p.Address,
		PrivateKey: key,
		ListenPort: p.ListenPort,
	}, nil
}

// peersConfig holds the raw [Peer] fields as parallel slices: element i of
// each slice belongs to peer i. PeersCount requires all three slices to
// have the same length.
type peersConfig struct {
	PublicKey  []string
	Endpoint   []string
	AllowedIPs []string
}

// ExportPeer converts the i-th peer of the raw gcfg data into a
// [PeerConfig]. Element i of Endpoint, PublicKey and AllowedIPs is used.
//
// It returns the zero PeerConfig and an error, annotated with the failing
// field, in these cases:
//   - the peer slices are inconsistent;
//   - i is out of range;
//   - any field fails to parse.
func (v *intermediateConfig) ExportPeer(i int) (PeerConfig, error) {
	n, ok := v.PeersCount()
	switch {
	case !ok:
		return PeerConfig{}, errors.New("inconsistent Peer data")
	case i < 0 || i >= n:
		// guard instead of panicking on a bad index
		return PeerConfig{}, fmt.Errorf("peer index %d out of range [0, %d)", i, n)
	}

	var (
		out PeerConfig
		err error
	)

	// Endpoint
	if err = out.Endpoint.FromString(v.Peer.Endpoint[i]); err != nil {
		return PeerConfig{}, core.Wrap(err, "Endpoint")
	}

	// PublicKey
	out.PublicKey, err = PublicKeyFromBase64(v.Peer.PublicKey[i])
	if err != nil {
		return PeerConfig{}, core.Wrap(err, "PublicKey")
	}

	// AllowedIPs
	out.AllowedIPs, err = parseAllowedIPs(v.Peer.AllowedIPs[i])
	if err != nil {
		return PeerConfig{}, core.Wrap(err, "AllowedIPs")
	}

	return out, nil
}

// parseAllowedIPs parses a comma-separated list of CIDR prefixes such as
// "10.0.0.0/8, fd00::/64". Whitespace around entries is ignored.
//
// Empty entries are skipped, so an empty value yields a nil slice and no
// error, and stray or trailing commas are tolerated. This keeps
// "AllowedIPs = " parseable; the writer emits that line for a peer with no
// prefixes. On a malformed entry it returns nil and the parse error.
func parseAllowedIPs(data string) ([]netip.Prefix, error) {
	var out []netip.Prefix

	for _, s := range strings.Split(data, ",") {
		s = strings.TrimSpace(s)
		if s == "" {
			continue
		}

		p, err := netip.ParsePrefix(s)
		if err != nil {
			return nil, err
		}

		out = append(out, p)
	}

	return out, nil
}

// PeersCount reports how many peers the raw data describes. The boolean is
// false, and the count 0, when Endpoint, PublicKey and AllowedIPs do not
// all have the same length.
func (v *intermediateConfig) PeersCount() (int, bool) {
	n := len(v.Peer.Endpoint)
	if len(v.Peer.PublicKey) == n && len(v.Peer.AllowedIPs) == n {
		return n, true
	}
	return 0, false
}

// NewConfigFromReader parses a wgN.conf file read from r.
//
// The text is decoded by gcfg into an intermediate representation, which is
// then validated and converted by its Export method. Any decoding or
// validation error is returned with a nil *Config.
func NewConfigFromReader(r io.Reader) (*Config, error) {
	var raw intermediateConfig
	if err := gcfg.ReadInto(&raw, r); err != nil {
		return nil, err
	}
	return raw.Export()
}