package server

import (
	"crypto/elliptic"
	"crypto/tls"
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"io"
	"net/http"
	"os"
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/valyala/fasthttp"

	"github.com/authelia/authelia/v4/internal/configuration/schema"
	"github.com/authelia/authelia/v4/internal/logging"
	"github.com/authelia/authelia/v4/internal/middlewares"
	"github.com/authelia/authelia/v4/internal/utils"
)

// TemporaryCertificate holds a certificate/key pair together with the two temporary files their PEM encodings
// are stored in.
type TemporaryCertificate struct {
	// CertFile and KeyFile are the open handles of the temporary files holding the PEM encoded certificate and
	// private key respectively.
	CertFile *os.File
	KeyFile  *os.File

	// Certificate is the parsed X.509 certificate.
	Certificate *x509.Certificate

	// CertificatePEM and KeyPEM are the in-memory PEM encodings of the certificate and private key.
	CertificatePEM []byte
	KeyPEM         []byte
}

// TLSCertificate loads the key pair from the temporary certificate and key files and returns it as a
// tls.Certificate. Both CertFile and KeyFile must be set.
func (tc TemporaryCertificate) TLSCertificate() (tls.Certificate, error) {
	return tls.LoadX509KeyPair(tc.CertFile.Name(), tc.KeyFile.Name())
}

// Close closes the temporary certificate and key files and removes them from disk. Files that were never created
// (nil handles) are skipped, and calling Close more than once is harmless.
func (tc *TemporaryCertificate) Close() {
	for _, f := range []*os.File{tc.CertFile, tc.KeyFile} {
		if f == nil {
			continue
		}

		// Cleanup is best effort: the handle may already be closed or the file already removed by an earlier
		// Close, and neither case is actionable for the caller.
		_ = f.Close()
		_ = os.Remove(f.Name())
	}
}

// CertificateContext keeps track of the temporary certificates generated for a test (Certificates) and the private
// key builder used to generate them (privateKeyBuilder).
type CertificateContext struct {
	Certificates      []TemporaryCertificate
	privateKeyBuilder utils.PrivateKeyBuilder
}

// NewCertificateContext builds a certificate context around the given private key builder and generates a first
// certificate, which ends up as the only entry of Certificates. It is meant to make certificate generation easy
// within tests.
func NewCertificateContext(privateKeyBuilder utils.PrivateKeyBuilder) (*CertificateContext, error) {
	cc := &CertificateContext{privateKeyBuilder: privateKeyBuilder}

	initial, err := cc.GenerateCertificate()
	if err != nil {
		return nil, err
	}

	// Make the list hold exactly the initial certificate.
	cc.Certificates = []TemporaryCertificate{*initial}

	return cc, nil
}

// GenerateCertificate generates a new certificate and private key with the context's private key builder, writes
// their PEM encodings to two temporary files and appends the result to cc.Certificates.
//
// It returns the new TemporaryCertificate, or an error when generation, parsing, or any of the file operations
// fail. On failure any temporary file handle opened so far is closed through TemporaryCertificate.Close.
//
// NOTE(review): the literal arguments passed to utils.GenerateCertificate look like host names, validity start,
// validity duration and a CA flag; confirm against that helper's signature.
func (cc *CertificateContext) GenerateCertificate() (*TemporaryCertificate, error) {
	certBytes, keyBytes, err := utils.GenerateCertificate(cc.privateKeyBuilder,
		[]string{"authelia.com", "example.org", "local.example.com"},
		time.Now(), 3*time.Hour, false)
	if err != nil {
		return nil, fmt.Errorf("unable to generate certificate: %w", err)
	}

	// Decode and parse before touching the filesystem so a malformed certificate cannot leave an open temporary
	// file behind.
	block, _ := pem.Decode(certBytes)
	if block == nil {
		return nil, fmt.Errorf("unable to parse certificate: no PEM data found")
	}

	c, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		return nil, fmt.Errorf("unable to parse certificate: %w", err)
	}

	tmpCertificate := &TemporaryCertificate{
		Certificate:    c,
		CertificatePEM: certBytes,
		KeyPEM:         keyBytes,
	}

	if tmpCertificate.CertFile, err = os.CreateTemp("", "cert"); err != nil {
		return nil, fmt.Errorf("unable to create temp file for certificate: %w", err)
	}

	if err = os.WriteFile(tmpCertificate.CertFile.Name(), certBytes, 0600); err != nil {
		tmpCertificate.Close()
		return nil, fmt.Errorf("unable to write certificates in file: %w", err)
	}

	if tmpCertificate.KeyFile, err = os.CreateTemp("", "key"); err != nil {
		tmpCertificate.Close()
		return nil, fmt.Errorf("unable to create temp file for private key: %w", err)
	}

	if err = os.WriteFile(tmpCertificate.KeyFile.Name(), keyBytes, 0600); err != nil {
		tmpCertificate.Close()
		return nil, fmt.Errorf("unable to write private key in file: %w", err)
	}

	cc.Certificates = append(cc.Certificates, *tmpCertificate)

	return tmpCertificate, nil
}

// Close closes every certificate tracked in cc.Certificates.
func (cc *CertificateContext) Close() {
	for i := range cc.Certificates {
		cc.Certificates[i].Close()
	}
}

// TLSServerContext wraps the fasthttp server started for a test (server) together with the port parsed from its
// listener address (port).
type TLSServerContext struct {
	server *fasthttp.Server
	port   int
}

// NewTLSServerContext creates a server and listener with CreateDefaultServer for the given configuration, records
// the port found after the last ':' of the listener address, and serves on the listener in a background goroutine.
//
// It returns an error when the server cannot be created or the port cannot be parsed; in the latter case the
// listener is closed and the server is never started. If the listener address contains no ':' the port is left
// at zero.
func NewTLSServerContext(configuration schema.Configuration) (*TLSServerContext, error) {
	s, listener, err := CreateDefaultServer(configuration, middlewares.Providers{})
	if err != nil {
		return nil, err
	}

	serverContext := &TLSServerContext{server: s}

	// Resolve the port before serving so that a failure here does not leave a running server behind.
	addr := listener.Addr().String()
	if i := strings.LastIndex(addr, ":"); i >= 0 {
		port, err := strconv.ParseInt(addr[i+1:], 10, 32)
		if err != nil {
			// Best effort: the parse error is the one worth reporting.
			_ = listener.Close()

			return nil, fmt.Errorf("unable to parse port from address: %w", err)
		}

		serverContext.port = int(port)
	}

	go func() {
		if err := s.Serve(listener); err != nil {
			logging.Logger().Fatal(err)
		}
	}()

	return serverContext, nil
}

// Port returns the port recorded for the server's listener when the context was created.
func (sc *TLSServerContext) Port() int {
	return sc.port
}

// Close calls Shutdown on the underlying fasthttp server and returns the resulting error, if any.
func (sc *TLSServerContext) Close() error {
	return sc.server.Shutdown()
}

// TestShouldRaiseErrorWhenClientDoesNotSkipVerify starts a TLS server with a generated certificate and checks that
// a request made with http.DefaultClient fails with an unknown authority error.
func TestShouldRaiseErrorWhenClientDoesNotSkipVerify(t *testing.T) {
	privateKeyBuilder := utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P256())
	certificateContext, err := NewCertificateContext(privateKeyBuilder)
	require.NoError(t, err)

	defer certificateContext.Close()

	tlsServerContext, err := NewTLSServerContext(schema.Configuration{
		Server: schema.ServerConfiguration{
			TLS: schema.ServerTLSConfiguration{
				Certificate: certificateContext.Certificates[0].CertFile.Name(),
				Key:         certificateContext.Certificates[0].KeyFile.Name(),
			},
		},
	})
	require.NoError(t, err)

	defer tlsServerContext.Close()

	// Log through the test instead of printing to stdout so the port only shows up in verbose or failing runs.
	t.Logf("TLS server listening on port %d", tlsServerContext.Port())

	req, err := http.NewRequest("GET", fmt.Sprintf("https://local.example.com:%d", tlsServerContext.Port()), nil)
	require.NoError(t, err)

	_, err = http.DefaultClient.Do(req)
	require.Error(t, err)

	require.Contains(t, err.Error(), "x509: certificate signed by unknown authority")
}

func TestShouldServeOverTLSWhenClientDoesSkipVerify(t *testing.T) {
	privateKeyBuilder := utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P256())
	certificateContext, err := NewCertificateContext(privateKeyBuilder)
	require.NoError(t, err)

	defer certificateContext.Close()

	tlsServerContext, err := NewTLSServerContext(schema.Configuration{
		Server: schema.ServerConfiguration{
			TLS: schema.ServerTLSConfiguration{
				Certificate: certificateContext.Certificates[0].CertFile.Name(),
				Key:         certificateContext.Certificates[0].KeyFile.Name(),
			},
		},
	})
	require.NoError(t, err)

	defer tlsServerContext.Close()

	req, err := http.NewRequest("GET", fmt.Sprintf("https://local.example.com:%d/api/notfound", tlsServerContext.Port()), nil)
	require.NoError(t, err)

	tr := &http.Transport{
		TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // Needs to be enabled in tests. Not used in production.
	}
	client := &http.Client{Transport: tr}

	res, err := client.Do(req)
	require.NoError(t, err)

	defer res.Body.Close()

	_, err = io.ReadAll(res.Body)
	require.NoError(t, err)
	assert.Equal(t, "404 Not Found", res.Status)
}

// TestShouldServeOverTLSWhenClientHasProperRootCA starts a TLS server with a generated certificate and checks that
// a client trusting that certificate as a root CA can complete a request, asserting the "404 Not Found" status
// returned for /api/notfound.
func TestShouldServeOverTLSWhenClientHasProperRootCA(t *testing.T) {
	privateKeyBuilder := utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P256())
	certificateContext, err := NewCertificateContext(privateKeyBuilder)
	require.NoError(t, err)

	defer certificateContext.Close()

	tlsServerContext, err := NewTLSServerContext(schema.Configuration{
		Server: schema.ServerConfiguration{
			TLS: schema.ServerTLSConfiguration{
				Certificate: certificateContext.Certificates[0].CertFile.Name(),
				Key:         certificateContext.Certificates[0].KeyFile.Name(),
			},
		},
	})
	require.NoError(t, err)

	defer tlsServerContext.Close()

	req, err := http.NewRequest("GET", fmt.Sprintf("https://local.example.com:%d/api/notfound", tlsServerContext.Port()), nil)
	require.NoError(t, err)

	// pem.Decode returns a nil block when no PEM data is found; fail the test cleanly instead of panicking on the
	// dereference below.
	block, _ := pem.Decode(certificateContext.Certificates[0].CertificatePEM)
	require.NotNil(t, block)

	c, err := x509.ParseCertificate(block.Bytes)
	require.NoError(t, err)

	// Create a root CA for the client to properly validate server cert.
	rootCAs := x509.NewCertPool()
	rootCAs.AddCert(c)

	tr := &http.Transport{
		TLSClientConfig: &tls.Config{
			RootCAs:    rootCAs,
			MinVersion: tls.VersionTLS13,
		},
	}
	client := &http.Client{Transport: tr}

	res, err := client.Do(req)
	require.NoError(t, err)

	defer res.Body.Close()

	_, err = io.ReadAll(res.Body)
	require.NoError(t, err)
	assert.Equal(t, "404 Not Found", res.Status)
}

func TestShouldRaiseWhenMutualTLSIsConfiguredAndClientIsNotAuthenticated(t *testing.T) {
	privateKeyBuilder := utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P256())
	certificateContext, err := NewCertificateContext(privateKeyBuilder)
	require.NoError(t, err)

	defer certificateContext.Close()

	clientCert, err := certificateContext.GenerateCertificate()
	require.NoError(t, err)

	tlsServerContext, err := NewTLSServerContext(schema.Configuration{
		Server: schema.ServerConfiguration{
			TLS: schema.ServerTLSConfiguration{
				Certificate:        certificateContext.Certificates[0].CertFile.Name(),
				Key:                certificateContext.Certificates[0].KeyFile.Name(),
				ClientCertificates: []string{clientCert.CertFile.Name()},
			},
		},
	})
	require.NoError(t, err)

	defer tlsServerContext.Close()

	req, err := http.NewRequest("GET", fmt.Sprintf("https://local.example.com:%d/api/notfound", tlsServerContext.Port()), nil)
	require.NoError(t, err)

	// Create a root CA for the client to properly validate server cert.
	rootCAs := x509.NewCertPool()
	rootCAs.AddCert(certificateContext.Certificates[0].Certificate)

	tr := &http.Transport{
		TLSClientConfig: &tls.Config{
			RootCAs:    rootCAs,
			MinVersion: tls.VersionTLS13,
		},
	}
	client := &http.Client{Transport: tr}

	_, err = client.Do(req)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "remote error: tls: bad certificate")
}

func TestShouldServeProperlyWhenMutualTLSIsConfiguredAndClientIsAuthenticated(t *testing.T) {
	privateKeyBuilder := utils.ECDSAKeyBuilder{}.WithCurve(elliptic.P256())
	certificateContext, err := NewCertificateContext(privateKeyBuilder)
	require.NoError(t, err)

	defer certificateContext.Close()

	clientCert, err := certificateContext.GenerateCertificate()
	require.NoError(t, err)

	tlsServerContext, err := NewTLSServerContext(schema.Configuration{
		Server: schema.ServerConfiguration{
			TLS: schema.ServerTLSConfiguration{
				Certificate:        certificateContext.Certificates[0].CertFile.Name(),
				Key:                certificateContext.Certificates[0].KeyFile.Name(),
				ClientCertificates: []string{clientCert.CertFile.Name()},
			},
		},
	})
	require.NoError(t, err)

	defer tlsServerContext.Close()

	req, err := http.NewRequest("GET", fmt.Sprintf("https://local.example.com:%d/api/notfound", tlsServerContext.Port()), nil)
	require.NoError(t, err)

	// Create a root CA for the client to properly validate server cert.
	rootCAs := x509.NewCertPool()
	rootCAs.AddCert(certificateContext.Certificates[0].Certificate)

	cCert, err := certificateContext.Certificates[1].TLSCertificate()
	require.NoError(t, err)

	tr := &http.Transport{
		TLSClientConfig: &tls.Config{
			RootCAs:      rootCAs,
			Certificates: []tls.Certificate{cCert},
			MinVersion:   tls.VersionTLS13,
		},
	}
	client := &http.Client{Transport: tr}

	res, err := client.Do(req)
	require.NoError(t, err)

	defer res.Body.Close()

	_, err = io.ReadAll(res.Body)
	require.NoError(t, err)
	assert.Equal(t, "404 Not Found", res.Status)
}