package middlewares

import (
	"bytes"
	"net/url"
	"strconv"
	"strings"

	"github.com/valyala/fasthttp"

	"github.com/authelia/authelia/v4/internal/utils"
)

// NewCORSPolicyBuilder returns a new CORSPolicyBuilder which is used to build a CORSPolicy. With the defaults set here
// (enabled, maximum age of 100) the built policy adds the Vary header with a value reflecting that the response varies
// by the Origin header, then if the Origin header has a https scheme it makes the following additional adjustments:
// copies the Origin header to the Access-Control-Allow-Origin header effectively allowing all origins, sets the
// Access-Control-Allow-Credentials header to false which disallows CORS requests from sending cookies etc, sets the
// Access-Control-Allow-Headers header to the value specified by Access-Control-Request-Headers in the request
// excluding the Cookie/Authorization/Proxy-Authorization and special * values, sets Access-Control-Allow-Methods to
// the value specified by the Access-Control-Request-Method header, sets the Access-Control-Max-Age header to 100.
//
// These behaviours can be overridden by the With methods on the returned builder.
func NewCORSPolicyBuilder() (policy *CORSPolicyBuilder) {
	return &CORSPolicyBuilder{
		enabled: true,
		maxAge:  100,
	}
}

// CORSPolicyBuilder is a special middleware which provides CORS headers via handlers and middleware methods which can be
// configured. It aims to simplify CORS configurations. Configure it using the With methods and call Build to produce
// the CORSPolicy that actually handles requests.
//
// Fields:
//   - enabled: whether the built policy adds any Vary or CORS headers at all.
//   - varyOnly: when true the built policy adds the Vary header but no other CORS header.
//   - varySet: records that WithVary was called, so that an explicitly empty Vary list is honoured.
//   - methods: values for the Allow and Access-Control-Allow-Methods headers; when empty the requested method is
//     echoed in Access-Control-Allow-Methods instead.
//   - headers: values for the Access-Control-Allow-Headers header; when empty the requested headers are filtered and
//     echoed instead.
//   - origins: the allowed origins; when empty any origin with a https scheme is reflected.
//   - credentials: the value for the Access-Control-Allow-Credentials header.
//   - vary: the header names for the Vary header, only consulted when varySet is true.
//   - maxAge: the value for the Access-Control-Max-Age header; values less than or equal to zero make Build fall back
//     to headerValueMaxAge.
type CORSPolicyBuilder struct {
	enabled     bool
	varyOnly    bool
	varySet     bool
	methods     []string
	headers     []string
	origins     []string
	credentials bool
	vary        []string
	maxAge      int
}

// Build snapshots the current CORSPolicyBuilder configuration into a CORSPolicy whose header values are pre-rendered
// as byte slices. The methods value is left nil when no methods are configured, and a max age less than or equal to
// zero is replaced by headerValueMaxAge.
func (b *CORSPolicyBuilder) Build() (policy *CORSPolicy) {
	var methods, maxAge []byte

	if len(b.methods) != 0 {
		methods = []byte(strings.Join(b.methods, ", "))
	}

	switch {
	case b.maxAge > 0:
		maxAge = []byte(strconv.Itoa(b.maxAge))
	default:
		maxAge = headerValueMaxAge
	}

	return &CORSPolicy{
		enabled:     b.enabled,
		varyOnly:    b.varyOnly,
		methods:     methods,
		headers:     b.buildHeaders(),
		origins:     b.buildOrigins(),
		credentials: []byte(strconv.FormatBool(b.credentials)),
		vary:        b.buildVary(),
		maxAge:      maxAge,
	}
}

// buildOrigins converts every configured origin, the * wildcard included, into its byte form. The result stays nil
// when no origins are configured.
func (b CORSPolicyBuilder) buildOrigins() (origins [][]byte) {
	for _, configured := range b.origins {
		origins = append(origins, []byte(configured))
	}

	return origins
}

// buildHeaders renders the configured allowed headers into the Access-Control-Allow-Headers value. It returns nil when
// no headers are configured. When credentials are allowed the Cookie, Authorization, and Proxy-Authorization headers
// are added to the list unless already present (compared case-insensitively).
func (b CORSPolicyBuilder) buildHeaders() (headers []byte) {
	if len(b.headers) == 0 {
		return nil
	}

	// Work on a private copy: appending to b.headers directly could write into spare capacity of the slice the caller
	// handed to WithAllowedHeaders and silently modify their data. The extra capacity covers the three credential
	// headers that may be appended below.
	h := make([]string, len(b.headers), len(b.headers)+3)
	copy(h, b.headers)

	if b.credentials {
		credentialHeaders := [...]string{
			fasthttp.HeaderCookie,
			fasthttp.HeaderAuthorization,
			fasthttp.HeaderProxyAuthorization,
		}

		for _, name := range credentialHeaders {
			if !utils.IsStringInSliceFold(name, h) {
				h = append(h, name)
			}
		}
	}

	return utils.JoinAndCanonicalizeHeaders(headerSeparator, h...)
}

// buildVary renders the Vary header value. An explicit WithVary configuration always wins: an empty list yields nil
// (no Vary header), anything else is joined and canonicalized. Without an explicit configuration the value is
// headerValueVaryWildcard when the only configured origin is *, and headerValueVary otherwise.
func (b CORSPolicyBuilder) buildVary() (vary []byte) {
	switch {
	case b.varySet && len(b.vary) == 0:
		return nil
	case b.varySet:
		return utils.JoinAndCanonicalizeHeaders(headerSeparator, b.vary...)
	case len(b.origins) == 1 && b.origins[0] == "*":
		return headerValueVaryWildcard
	default:
		return headerValueVary
	}
}

// WithEnabled changes the enabled state of the middleware. If the builder is initialized with NewCORSPolicyBuilder this
// value will be true but this function can override the value. Setting it to false prevents the middleware from adding
// any CORS headers, the Vary header included. The only effect the built policy has after disabling this is that the
// HandleOPTIONS and HandleOnlyOPTIONS handlers still respond with a HTTP 200 OK without a body, with the Allow header
// communicating the configured HTTP method verbs. The main benefit of this option is that you don't have to implement
// complex logic to add/remove the middleware, you can just add it with the Middleware method, and adjust it using the
// WithEnabled method.
func (b *CORSPolicyBuilder) WithEnabled(enabled bool) (policy *CORSPolicyBuilder) {
	b.enabled = enabled

	return b
}

// WithAllowedMethods takes a list of HTTP methods and adjusts the Access-Control-Allow-Methods header to respond with
// that value. The same list is used for the Allow header written by the HandleOPTIONS and HandleOnlyOPTIONS handlers.
// When no methods are configured the Access-Control-Allow-Methods header echoes the Access-Control-Request-Method
// header of the request instead. The provided slice is stored as-is, it is not copied. Returns the receiver to allow
// chaining.
func (b *CORSPolicyBuilder) WithAllowedMethods(methods ...string) (policy *CORSPolicyBuilder) {
	b.methods = methods

	return b
}

// WithAllowedOrigins takes a list of origin strings and only applies the CORS policy if the Origin header of the
// request is byte-for-byte equal to one of them, or if one of them is the wildcard value (headerValueOriginWildcard).
// When no origins are configured any origin with a https scheme is allowed. The provided slice is stored as-is, it is
// not copied. Returns the receiver to allow chaining.
func (b *CORSPolicyBuilder) WithAllowedOrigins(origins ...string) (policy *CORSPolicyBuilder) {
	b.origins = origins

	return b
}

// WithAllowedHeaders takes a list of header strings and alters the default Access-Control-Allow-Headers header. If
// credentials are allowed when Build is called, the Cookie, Authorization, and Proxy-Authorization headers are added to
// the list when missing. When no headers are configured the value is derived from the Access-Control-Request-Headers
// header of the request instead. The provided slice is stored as-is, it is not copied. Returns the receiver to allow
// chaining.
func (b *CORSPolicyBuilder) WithAllowedHeaders(headers ...string) (policy *CORSPolicyBuilder) {
	b.headers = headers

	return b
}

// WithAllowCredentials takes a bool and alters the default Access-Control-Allow-Credentials header, which is false for
// a builder that never calls this method. Allowing credentials also permits the Cookie, Authorization, and
// Proxy-Authorization headers in the Access-Control-Allow-Headers header. Returns the receiver to allow chaining.
func (b *CORSPolicyBuilder) WithAllowCredentials(allow bool) (policy *CORSPolicyBuilder) {
	b.credentials = allow

	return b
}

// WithVary takes a list of header strings and alters the default Vary header. Calling it without any headers prevents
// the built policy from adding the Vary header at all, because the call itself is recorded via varySet. Returns the
// receiver to allow chaining.
func (b *CORSPolicyBuilder) WithVary(headers ...string) (policy *CORSPolicyBuilder) {
	b.vary = headers
	b.varySet = true

	return b
}

// WithVaryOnly controls whether the built policy just adds the Vary header and skips every other CORS header. Returns
// the receiver to allow chaining.
func (b *CORSPolicyBuilder) WithVaryOnly(varyOnly bool) (policy *CORSPolicyBuilder) {
	b.varyOnly = varyOnly

	return b
}

// WithMaxAge takes an integer and alters the default Access-Control-Max-Age header. Values less than or equal to zero
// make Build fall back to the package default (headerValueMaxAge). Returns the receiver to allow chaining.
func (b *CORSPolicyBuilder) WithMaxAge(age int) (policy *CORSPolicyBuilder) {
	b.maxAge = age

	return b
}

// CORSPolicy is a middleware that handles adding CORS headers. It is produced by CORSPolicyBuilder.Build and holds the
// header values pre-rendered as byte slices.
//
// Fields:
//   - enabled: when false no Vary or CORS headers are added.
//   - varyOnly: when true only the Vary header is added.
//   - methods: the Allow / Access-Control-Allow-Methods value; when empty the requested method is echoed instead.
//   - headers: the Access-Control-Allow-Headers value; when empty the requested headers are filtered and echoed.
//   - origins: the allowed origins; when empty any origin with a https scheme is reflected.
//   - credentials: the Access-Control-Allow-Credentials value.
//   - vary: the Vary value; when empty the Vary header is not added.
//   - maxAge: the Access-Control-Max-Age value.
type CORSPolicy struct {
	enabled     bool
	varyOnly    bool
	methods     []byte
	headers     []byte
	origins     [][]byte
	credentials []byte
	vary        []byte
	maxAge      []byte
}

// HandleOPTIONS is an OPTIONS handler that just adds CORS headers, the Allow header, and sets the status code to 200
// with an explicit zero Content-Length and no body. The CORS headers are only added when the policy is enabled. This
// handler should generally not be used without using WithAllowedMethods.
func (p CORSPolicy) HandleOPTIONS(ctx *fasthttp.RequestCtx) {
	p.handleOPTIONS(ctx)
	p.handle(ctx)
}

// HandleOnlyOPTIONS is an OPTIONS handler that just handles the Allow header, and sets the status code to 200 with an
// explicit zero Content-Length and no body. It never adds CORS headers. This handler should generally not be used
// without using WithAllowedMethods.
func (p CORSPolicy) HandleOnlyOPTIONS(ctx *fasthttp.RequestCtx) {
	p.handleOPTIONS(ctx)
}

// Middleware provides a middleware that adds the appropriate CORS headers for this CORSPolicy and then calls the next
// handler.
func (p CORSPolicy) Middleware(next fasthttp.RequestHandler) (handler fasthttp.RequestHandler) {
	return func(ctx *fasthttp.RequestCtx) {
		p.handle(ctx)

		next(ctx)
	}
}

// handle applies the policy to the response: nothing when the policy is disabled, only the Vary header when the policy
// is vary-only, and the Vary header followed by the CORS headers otherwise.
func (p CORSPolicy) handle(ctx *fasthttp.RequestCtx) {
	if !p.enabled {
		return
	}

	p.handleVary(ctx)

	if p.varyOnly {
		return
	}

	p.handleCORS(ctx)
}

// handleOPTIONS writes the shared part of an OPTIONS response: the body is reset, the status is set to 200, the
// Content-Length header is set explicitly, and the Allow header is added when methods have been configured.
func (p CORSPolicy) handleOPTIONS(ctx *fasthttp.RequestCtx) {
	ctx.Response.ResetBody()

	// A 204 is deliberately not used for OPTIONS, as the following specifications rule it out when read together:
	//
	// RFC7231 (https://www.rfc-editor.org/rfc/rfc7231#section-4.3.7):
	//	A server MUST generate a Content-Length field with a value of "0" if no payload body is to be sent in
	//	the response.
	//
	// RFC7230 (https://www.rfc-editor.org/rfc/rfc7230#section-3.3.2):
	//	A server MUST NOT send a Content-Length header field in any response with a status code of 1xx
	//	(Informational) or 204 (No Content).
	ctx.SetStatusCode(fasthttp.StatusOK)

	header := &ctx.Response.Header

	header.SetBytesKV(headerContentLength, headerValueZero)

	if len(p.methods) == 0 {
		return
	}

	header.SetBytesKV(headerAllow, p.methods)
}

// handleVary sets the Vary response header to the pre-rendered value, or leaves the response untouched when the policy
// has no Vary value.
func (p CORSPolicy) handleVary(ctx *fasthttp.RequestCtx) {
	if len(p.vary) == 0 {
		return
	}

	ctx.Response.Header.SetBytesKV(headerVary, p.vary)
}

// handleCORS decides whether the request origin is allowed and, if so, writes the Access-Control-Allow-Origin,
// Access-Control-Allow-Credentials, Access-Control-Max-Age, Access-Control-Allow-Headers, and
// Access-Control-Allow-Methods response headers.
//
// The origin is allowed when:
//   - no origins are configured and the Origin header parses as a URL with the https scheme, in which case the origin
//     is reflected back; or
//   - origins are configured and one of them equals the Origin header byte-for-byte (the origin is reflected back), or
//     one of them is the wildcard value (the wildcard is sent back). If both kinds of entry match, the one that comes
//     last in the configured list wins.
//
// Nothing is written when the Origin header cannot be parsed as a URL or the origin is not allowed.
func (p CORSPolicy) handleCORS(ctx *fasthttp.RequestCtx) {
	origin := ctx.Request.Header.PeekBytes(headerOrigin)

	originURL, err := url.Parse(string(origin))
	if err != nil {
		return
	}

	var allowedOrigin []byte

	// The emptiness check uses len rather than a nil comparison so that an empty but non-nil origins slice is treated
	// exactly like an unset one.
	if len(p.origins) == 0 {
		// Skip processing of any non-https scheme URL that has not expressly been configured.
		if originURL.Scheme != "https" {
			return
		}

		allowedOrigin = origin
	} else {
		// The loop intentionally does not stop at the first match so that the last matching entry wins.
		for _, allowed := range p.origins {
			switch {
			case bytes.Equal(allowed, headerValueOriginWildcard):
				allowedOrigin = headerValueOriginWildcard
			case bytes.Equal(allowed, origin):
				allowedOrigin = origin
			}
		}

		if len(allowedOrigin) == 0 {
			return
		}
	}

	ctx.Response.Header.SetBytesKV(headerAccessControlAllowOrigin, allowedOrigin)

	if len(p.credentials) != 0 {
		ctx.Response.Header.SetBytesKV(headerAccessControlAllowCredentials, p.credentials)
	}

	if len(p.maxAge) != 0 {
		ctx.Response.Header.SetBytesKV(headerAccessControlMaxAge, p.maxAge)
	}

	p.handleAllowedHeaders(ctx)
	p.handleAllowedMethods(ctx)
}

// handleAllowedMethods sets the Access-Control-Allow-Methods response header. Configured methods take precedence;
// without them the Access-Control-Request-Method header of the request is echoed back when it is present.
func (p CORSPolicy) handleAllowedMethods(ctx *fasthttp.RequestCtx) {
	if len(p.methods) != 0 {
		ctx.Response.Header.SetBytesKV(headerAccessControlAllowMethods, p.methods)

		return
	}

	// TODO: It may be beneficial to be able to control this automatic behaviour.
	requested := ctx.Request.Header.PeekBytes(headerAccessControlRequestMethod)
	if requested == nil {
		return
	}

	ctx.Response.Header.SetBytesKV(headerAccessControlAllowMethods, requested)
}

// handleAllowedHeaders sets the Access-Control-Allow-Headers response header. Configured headers take precedence.
// Without them the comma separated Access-Control-Request-Headers header of the request is echoed back after
// filtering: empty entries and the special * value are always dropped, and the Cookie, Authorization, and
// Proxy-Authorization headers are dropped unless the policy allows credentials (p.credentials equals
// headerValueTrue). The header is not set when nothing remains after filtering.
func (p CORSPolicy) handleAllowedHeaders(ctx *fasthttp.RequestCtx) {
	if len(p.headers) != 0 {
		ctx.Response.Header.SetBytesKV(headerAccessControlAllowHeaders, p.headers)

		return
	}

	// TODO: It may be beneficial to be able to control this automatic behaviour.
	requested := ctx.Request.Header.PeekBytes(headerAccessControlRequestHeaders)
	if requested == nil {
		return
	}

	// The credentials setting does not change per entry, so evaluate it once up front.
	allowCredentialHeaders := bytes.Equal(p.credentials, headerValueTrue)

	names := strings.Split(string(requested), ",")
	allowed := make([]string, 0, len(names))

	for _, name := range names {
		// Optional whitespace around list entries may be spaces or horizontal tabs.
		name = strings.Trim(name, " \t")

		// Empty entries (for example from "a,,b" or an empty header value) would otherwise be echoed back as empty
		// header names.
		if name == "" || name == "*" {
			continue
		}

		if !allowCredentialHeaders && (strings.EqualFold(fasthttp.HeaderCookie, name) ||
			strings.EqualFold(fasthttp.HeaderAuthorization, name) ||
			strings.EqualFold(fasthttp.HeaderProxyAuthorization, name)) {
			continue
		}

		allowed = append(allowed, name)
	}

	if len(allowed) != 0 {
		ctx.Response.Header.SetBytesKV(headerAccessControlAllowHeaders, []byte(strings.Join(allowed, ", ")))
	}
}