import Konva from 'konva';
import { AppError } from '@nirby/js-utils/errors';

/** Orientation of a flex parent's main axis; the cross axis is the other one. */
export type FlexDirection = 'horizontal' | 'vertical';
/**
 * How children are distributed along the parent's main axis.
 *
 * NOTE(review): GroupFlex only special-cases `'center'`; the remaining values
 * are currently laid out the same way.
 */
export type MainAxisAlignment =
    | 'start'
    | 'between'
    | 'end'
    | 'center'
    | 'stretch';
/**
 * How each child is placed along the parent's cross axis.
 *
 * NOTE(review): GroupFlex only special-cases `'center'`; the remaining values
 * are currently laid out the same way.
 */
export type CrossAxisAlignment = 'start' | 'end' | 'center' | 'stretch';

/** Per-node settings describing how a node is laid out inside its parent. */
export interface FlexChild {
    /** Layout mode; `'free'` nodes are skipped by GroupFlex. */
    display: DisplayType;
    /** Weight of this node's share of the main axis among sibling `'flex'` nodes. */
    flex: number;
    /** Size limits along the parent's main axis. */
    mainAxisConstraints: AxisConstraints;
    /** Size limits along the parent's cross axis. */
    crossAxisConstraints: AxisConstraints;
    aspectRatio?: number; // only used on fill display types; interpreted as width / height
}

/**
 * Optional lower/upper bounds on a node's size along one axis.
 * GroupFlex rejects (throws) constraints whose `min` exceeds `max`.
 */
export interface AxisConstraints {
    min?: number;
    max?: number;
}

/** Settings a node applies to the layout of its own children. */
export interface FlexParent {
    /** Orientation of the main axis. */
    direction: FlexDirection;
    mainAxisAlignment: MainAxisAlignment;
    crossAxisAlignment: CrossAxisAlignment;
    /** Per-axis padding; `main` follows `direction`, `cross` is the other axis. */
    padding: Padding;
}

/** Full layout settings for a node, which can act both as a child and as a parent. */
export type Flex = FlexChild & FlexParent;

/**
 * Layout settings keyed by Konva node name. GroupFlex splits a node's name on
 * spaces and uses the first key that matches; unset fields fall back to defaults.
 */
export type FlexNames = Record<string, Partial<Flex>>;

/**
 * - `'flex'`: shares the parent's main axis with sibling `'flex'` nodes.
 * - `'fill'`: spans the parent's whole content area (optionally fitted to `aspectRatio`).
 * - `'free'`: ignored by the layout; its position and size are left untouched.
 */
export type DisplayType = 'flex' | 'free' | 'fill';

/**
 * Baseline layout settings. `GroupFlex.completeFlex` takes every field a
 * node's `FlexNames` entry leaves unset from here. `aspectRatio` is
 * deliberately absent, so it stays undefined unless a node sets it.
 */
const defaultFlex: Flex = {
    // --- as a child ---
    display: 'flex',
    flex: 1,
    mainAxisConstraints: {},
    crossAxisConstraints: {},
    // --- as a parent ---
    direction: 'horizontal',
    mainAxisAlignment: 'start',
    crossAxisAlignment: 'stretch',
    padding: {
        main: { start: 0, end: 0 },
        cross: { start: 0, end: 0 },
    },
};

/**
 * A span along one axis, from `start` to `end`. Also reused for padding,
 * where `start`/`end` are the insets at each end of the axis.
 */
interface SingleAxisPosition {
    start: number;
    end: number;
}

/** Per-axis padding; `main` follows the parent's `direction`. */
interface Padding {
    main: SingleAxisPosition;
    cross: SingleAxisPosition;
}

/** A size keyed by axis, so it can be indexed directly with a `FlexDirection`. */
interface Dimensions {
    horizontal: number;
    vertical: number;
}

/** A node's computed span on its parent's main and cross axes. */
interface FullNodePosition {
    main: SingleAxisPosition;
    cross: SingleAxisPosition;
}

/** A child node paired with its resolved (fully defaulted) flex settings. */
interface FlexNodeContainer {
    flex: Flex;
    node: Konva.Node;
}

/**
 * Lays out the descendants of a Konva node with a small flexbox-like model.
 *
 * Each node's settings are looked up in `flex` by the node's space-separated
 * names (the first name with an entry wins). Any unset field is taken from
 * `defaultFlex`. {@link GroupFlex.update} then positions and sizes every
 * non-`'free'` child along its parent's main axis (`direction`) and cross
 * axis, and recurses into each child it positioned.
 *
 * NOTE(review): only the `'center'` alignments change placement; every other
 * alignment value is currently laid out the same way.
 */
export class GroupFlex {
    /**
     * @param container - Root node whose descendants are laid out by `update()`.
     * @param flex - Layout settings keyed by node name.
     */
    public constructor(
        private readonly container: Konva.Node,
        private readonly flex: FlexNames,
    ) {}

    /**
     * Returns a complete `Flex`, filling every field `flex` leaves unset from
     * `defaultFlex`. Padding is filled per side.
     */
    private static completeFlex(flex?: Partial<Flex>): Flex {
        const source: Partial<Flex> = flex ?? {};
        // `?.` on each axis also guards untyped input that omits `main`/`cross`.
        const padding = source.padding;
        return {
            display: source.display ?? defaultFlex.display,
            flex: source.flex ?? defaultFlex.flex,
            aspectRatio: source.aspectRatio ?? defaultFlex.aspectRatio,
            mainAxisConstraints:
                source.mainAxisConstraints ?? defaultFlex.mainAxisConstraints,
            crossAxisConstraints:
                source.crossAxisConstraints ?? defaultFlex.crossAxisConstraints,
            direction: source.direction ?? defaultFlex.direction,
            mainAxisAlignment:
                source.mainAxisAlignment ?? defaultFlex.mainAxisAlignment,
            crossAxisAlignment:
                source.crossAxisAlignment ?? defaultFlex.crossAxisAlignment,
            padding: {
                main: {
                    start:
                        padding?.main?.start ?? defaultFlex.padding.main.start,
                    end: padding?.main?.end ?? defaultFlex.padding.main.end,
                },
                cross: {
                    start:
                        padding?.cross?.start ??
                        defaultFlex.padding.cross.start,
                    end: padding?.cross?.end ?? defaultFlex.padding.cross.end,
                },
            },
        };
    }

    /**
     * Returns the content span left inside `[0, size]` after applying `padding`.
     *
     * If the padding would leave less than one unit of room, it is ignored and
     * the whole `[0, size]` span is returned.
     */
    private static getPaddedAxis(
        size: number,
        padding: SingleAxisPosition,
    ): SingleAxisPosition {
        const innerSize = size - padding.start - padding.end;
        if (innerSize < 1) {
            return { start: 0, end: size };
        }
        return { start: padding.start, end: padding.start + innerSize };
    }

    /**
     * Computes one node's span along a single axis.
     *
     * The node starts right after `previousNodePosition` (or at the container
     * start when there is none). It receives `flex / totalFlex` of the
     * container's length, clamped by `constraints` and by the container end.
     *
     * @param flex - This node's flex factor.
     * @param constraints - Optional min/max size along this axis.
     * @param containerPosition - Span the node is laid out within.
     * @param spacing - Gap inserted after a preceding node (not before the first).
     * @param previousNodePosition - Span of the preceding node, or `null` for the first.
     * @param totalFlex - Sum of the flex factors sharing the container.
     * @returns The node's start/end along the axis.
     * @throws AppError when `constraints.min` is greater than `constraints.max`.
     */
    private static getAxisPositionSingle(
        flex: number,
        constraints: AxisConstraints,
        containerPosition: SingleAxisPosition,
        spacing: number,
        previousNodePosition: SingleAxisPosition | null = null,
        totalFlex: number,
    ): SingleAxisPosition {
        const { min, max } = constraints;
        // `!= null` (not truthiness) so that a constraint of 0 is honored.
        if (min != null && max != null && min > max) {
            throw new AppError(
                `min constraint (${min}) is greater than max constraint (${max}) in flex`,
            );
        }

        const previousPosition =
            previousNodePosition?.end ?? containerPosition.start;
        // Spacing only separates a node from an actual predecessor.
        const gap = previousNodePosition ? spacing : 0;
        const start = Math.max(
            previousPosition + gap,
            containerPosition.start,
        );

        const containerSize = containerPosition.end - containerPosition.start;
        // Avoid 0 / 0 = NaN when no flex is being shared.
        let flexSize =
            totalFlex > 0 ? (flex / totalFlex) * containerSize : 0;

        if (min != null) {
            flexSize = Math.max(flexSize, min);
        }
        if (max != null) {
            flexSize = Math.min(
                flexSize,
                max,
                containerPosition.end - previousPosition,
            );
        }

        const end = Math.min(start + flexSize, containerPosition.end);
        return { start, end };
    }

    /** Lays out the whole tree below the container passed to the constructor. */
    public update(): void {
        this.updateNode(this.container);
    }

    /**
     * Resolves a node's settings: the first of its space-separated names that
     * has an own entry in `this.flex`, completed with defaults.
     */
    private getNodeFlex(node: Konva.Node): Flex {
        const entry = node
            .name()
            .split(' ')
            // Own keys only, so names like "toString" don't hit Object.prototype.
            .filter((name) =>
                Object.prototype.hasOwnProperty.call(this.flex, name),
            )
            .map((name) => this.flex[name])
            .find((f) => !!f);
        return GroupFlex.completeFlex(entry);
    }

    /**
     * Positions and sizes `parent`'s non-`'free'` children, then recurses into
     * each of them. `'free'` children are neither moved nor recursed into.
     */
    private updateNode(parent: Konva.Node): void {
        const parentFlex: FlexParent = this.getNodeFlex(parent);
        const isHorizontal = parentFlex.direction === 'horizontal';
        const children: Konva.Node[] = (parent as Konva.Group).children ?? [];
        const childrenContainer: FlexNodeContainer[] = children
            .map((c: Konva.Node) => ({
                flex: this.getNodeFlex(c),
                node: c,
            }))
            .filter((c) => c.flex.display !== 'free');

        const positions = this.getNodesPositions(
            childrenContainer,
            parentFlex,
            {
                horizontal: parent.width(),
                vertical: parent.height(),
            },
        );

        positions.forEach((pos, index) => {
            const child = childrenContainer[index].node;
            const mainSize = pos.main.end - pos.main.start;
            const crossSize = pos.cross.end - pos.cross.start;
            if (isHorizontal) {
                child.x(pos.main.start);
                child.width(mainSize);
                child.y(pos.cross.start);
                child.height(crossSize);
            } else {
                child.y(pos.main.start);
                child.height(mainSize);
                child.x(pos.cross.start);
                child.width(crossSize);
            }
            this.updateNode(child);
        });
    }

    /**
     * Computes the main/cross span of every node inside a parent of
     * `parentSize`, after applying the parent's padding, alignment and each
     * `'fill'` node's aspect ratio.
     */
    private getNodesPositions(
        nodes: FlexNodeContainer[],
        parentFlex: FlexParent,
        parentSize: Dimensions,
    ): FullNodePosition[] {
        const mainDirection = parentFlex.direction;
        const crossDirection: FlexDirection =
            mainDirection === 'vertical' ? 'horizontal' : 'vertical';

        const parentDimensions: FullNodePosition = {
            main: GroupFlex.getPaddedAxis(
                parentSize[mainDirection],
                parentFlex.padding.main,
            ),
            cross: GroupFlex.getPaddedAxis(
                parentSize[crossDirection],
                parentFlex.padding.cross,
            ),
        };

        const mainPositions = this.getMainAxisPositions(
            nodes,
            parentFlex.mainAxisAlignment,
            parentDimensions.main,
        );
        const crossPositions = this.getCrossAxisPositions(
            nodes,
            parentFlex.crossAxisAlignment,
            parentDimensions.cross,
        );

        // Shrink 'fill' nodes along one axis so they keep their aspect ratio,
        // centered in the space they were given.
        nodes.forEach((n, index) => {
            const ratio = n.flex.aspectRatio;
            if (n.flex.display !== 'fill' || !ratio || ratio <= 0) {
                return;
            }
            const main = mainPositions[index];
            const cross = crossPositions[index];

            // aspectRatio is width / height; convert it to main / cross.
            const mainPerCross =
                mainDirection === 'horizontal' ? ratio : 1 / ratio;
            const mainSize = main.end - main.start;
            const crossSize = cross.end - cross.start;
            const fittedMain = crossSize * mainPerCross;
            if (mainSize > fittedMain) {
                const excess = mainSize - fittedMain;
                main.start += excess / 2;
                main.end -= excess / 2;
            } else {
                // The cross size matching mainSize is mainSize / mainPerCross.
                const excess = crossSize - mainSize / mainPerCross;
                cross.start += excess / 2;
                cross.end -= excess / 2;
            }
        });

        return mainPositions.map((main, index) => ({
            main,
            cross: crossPositions[index],
        }));
    }

    /**
     * Lays `'flex'` nodes out one after another along the main axis, sharing
     * the container by their flex factors. `'fill'` nodes span the whole
     * container.
     *
     * @throws AppError for any other display type.
     */
    private getMainAxisPositions(
        nodesData: FlexNodeContainer[],
        alignment: MainAxisAlignment,
        containerPosition: SingleAxisPosition,
    ): SingleAxisPosition[] {
        const flexNodes = nodesData.filter((n) => n.flex.display === 'flex');
        const totalFlex = flexNodes.reduce(
            (sum, current) => sum + current.flex.flex,
            0,
        );
        const flexContainer =
            alignment === 'center'
                ? this.getCenteredContainer(
                      flexNodes.map((n) => n.flex.mainAxisConstraints),
                      containerPosition,
                  )
                : containerPosition;

        let lastNodePosition: SingleAxisPosition | null = null;
        return nodesData.map((cont) => {
            switch (cont.flex.display) {
                case 'flex':
                    lastNodePosition = GroupFlex.getAxisPositionSingle(
                        cont.flex.flex,
                        cont.flex.mainAxisConstraints,
                        flexContainer,
                        0,
                        lastNodePosition,
                        totalFlex,
                    );
                    return lastNodePosition;
                case 'fill':
                    return {
                        start: containerPosition.start,
                        end: containerPosition.end,
                    };
                default:
                    throw new AppError(
                        `Unknown display type ${cont.flex.display}`,
                    );
            }
        });
    }

    /**
     * Places each node independently on the cross axis. `'flex'` nodes span
     * the container up to their cross constraints (individually centered for
     * `'center'`). `'fill'` nodes span the whole container.
     *
     * @throws AppError for any other display type.
     */
    private getCrossAxisPositions(
        nodesData: FlexNodeContainer[],
        alignment: CrossAxisAlignment,
        containerPosition: SingleAxisPosition,
    ): SingleAxisPosition[] {
        return nodesData.map((cont) => {
            switch (cont.flex.display) {
                case 'flex': {
                    const constraints = cont.flex.crossAxisConstraints;
                    // Cross spans don't stack, so each node is centered on its own max.
                    const container =
                        alignment === 'center'
                            ? this.getCenteredContainer(
                                  [constraints],
                                  containerPosition,
                              )
                            : containerPosition;
                    // A 1/1 share takes the full cross axis regardless of the
                    // node's flex factor (which may be 0).
                    return GroupFlex.getAxisPositionSingle(
                        1,
                        constraints,
                        container,
                        0,
                        null,
                        1,
                    );
                }
                case 'fill':
                    return {
                        start: containerPosition.start,
                        end: containerPosition.end,
                    };
                default:
                    throw new AppError(
                        `Unknown display type ${cont.flex.display}`,
                    );
            }
        });
    }

    /**
     * Shrinks `containerPosition` symmetrically so that a run of nodes whose
     * sizes sum to their `max` constraints sits in the middle. If any node has
     * no `max` (the sum is infinite), the container is returned unchanged
     * (copied). The container is never grown.
     */
    private getCenteredContainer(
        nodesConstraints: AxisConstraints[],
        containerPosition: SingleAxisPosition,
    ): SingleAxisPosition {
        const realContainer = { ...containerPosition };
        const axisFullConstraint = nodesConstraints.reduce(
            (prev, curr) => prev + (curr.max ?? Number.POSITIVE_INFINITY),
            0,
        );
        if (Number.isFinite(axisFullConstraint)) {
            const maxSize = containerPosition.end - containerPosition.start;
            const space = Math.max((maxSize - axisFullConstraint) / 2, 0);
            realContainer.start = containerPosition.start + space;
            realContainer.end = containerPosition.end - space;
        }
        return realContainer;
    }
}
