using System;
using System.Collections.Generic;
using System.Data;
using System.Linq;
using AIStudio.Wpf.DiagramDesigner.Geometrys;

// Implementation taken from the JS version: https://gist.github.com/menendezpoo/4a8894c152383b9d7a870c24a04447e4
// TODO: Make it more idiomatic C#, benchmark A* vs Dijkstra, add more options
namespace AIStudio.Wpf.DiagramDesigner
{
    public static partial class Routers
    {
        /// <summary>
        /// Routes a link between two ports using a grid-based shortest-path search.
        /// </summary>
        /// <remarks>
        /// The bounds of both shapes (inflated by a margin unless the inflated bounds intersect) and one
        /// ruler through each port are turned into a grid. The free grid points plus the two port
        /// "antenna" points form a graph that is searched for a path between the antennas.
        /// </remarks>
        /// <param name="_">The diagram; only forwarded to <c>Normal</c> when falling back.</param>
        /// <param name="link">The connection to route.</param>
        /// <returns>
        /// The simplified path when one is found; otherwise the result of <c>Normal</c>. <c>Normal</c> is
        /// also used when the link is not a full connection.
        /// </returns>
        /// <exception cref="Exception">Thrown when the link is portless.</exception>
        public static PointBase[] Orthogonal(IDiagramViewModel _, ConnectionViewModel link)
        {
            if (link.IsPortless)
            {
                throw new Exception("Orthogonal router doesn't work with portless links yet");
            }

            if (link.IsFullConnection == false)
            {
                return Normal(_, link);
            }

            var shapeMargin = 10;
            var globalBoundsMargin = 50;

            var sideA = link.SourceConnectorInfo.Orientation;
            var sideAVertical = IsVerticalSide(sideA);
            var sideB = link.SinkConnectorInfo.Orientation;
            var sideBVertical = IsVerticalSide(sideB);
            var originA = GetPortPositionBasedOnAlignment(link.SourceConnectorInfo);
            var originB = GetPortPositionBasedOnAlignment(link.SinkConnectorInfo);
            var shapeA = link.SourceConnectorInfo.DataItem.GetBounds(includePorts: true);
            var shapeB = link.SinkConnectorInfoFully.DataItem.GetBounds(includePorts: true);
            var inflatedA = shapeA.InflateRectangle(shapeMargin, shapeMargin);
            var inflatedB = shapeB.InflateRectangle(shapeMargin, shapeMargin);

            // When the inflated shapes overlap, drop the margin and work with the raw bounds.
            var marginsOverlap = inflatedA.Intersects(inflatedB);
            if (marginsOverlap)
            {
                shapeMargin = 0;
                inflatedA = shapeA;
                inflatedB = shapeB;
            }

            // Overall area the grid is confined to.
            var bounds = inflatedA.UnionRectangle(inflatedB).InflateRectangle(globalBoundsMargin, globalBoundsMargin);

            // Rulers along the edges of both shapes.
            var verticals = new List<double> { inflatedA.Left, inflatedA.Right, inflatedB.Left, inflatedB.Right };
            var horizontals = new List<double> { inflatedA.Top, inflatedA.Bottom, inflatedB.Top, inflatedB.Bottom };

            // One extra ruler through each port, on the axis the port faces.
            if (sideAVertical)
            {
                verticals.Add(originA.X);
            }
            else
            {
                horizontals.Add(originA.Y);
            }

            if (sideBVertical)
            {
                verticals.Add(originB.X);
            }
            else
            {
                horizontals.Add(originB.Y);
            }

            verticals.Sort();
            horizontals.Sort();

            // Antenna points first, then every free point of the grid.
            var spots = new List<PointBase>
            {
                GetOriginSpot(originA, sideA, shapeMargin),
                GetOriginSpot(originB, sideB, shapeMargin)
            };
            var grid = RulersToGrid(verticals, horizontals, bounds);
            spots.AddRange(GridToSpots(grid, new[] { inflatedA, inflatedB }));

            var graph = CreateGraph(spots);

            // Search between the two antennas (ports pushed outwards by the margin).
            var origin = ExtrudeCp(originA, shapeMargin, sideA);
            var destination = ExtrudeCp(originB, shapeMargin, sideB);
            var path = ShortestPath(graph, origin, destination);

            return path.Length > 0 ? SimplifyPath(path) : Normal(_, link);
        }

        /// <summary>
        /// Returns the point obtained by moving <paramref name="p"/> away from its shape by
        /// <paramref name="shapeMargin"/> in the direction given by <paramref name="alignment"/>.
        /// </summary>
        /// <param name="p">The port position.</param>
        /// <param name="alignment">The side the port faces.</param>
        /// <param name="shapeMargin">Distance to move along that side's outward axis.</param>
        /// <returns>The offset point.</returns>
        /// <exception cref="NotImplementedException">
        /// Thrown for any orientation other than Top, Right, Bottom or Left.
        /// </exception>
        private static PointBase GetOriginSpot(PointBase p, ConnectorOrientation alignment, double shapeMargin)
        {
            if (alignment == ConnectorOrientation.Top)
            {
                return p.Add(0, -shapeMargin);
            }

            if (alignment == ConnectorOrientation.Right)
            {
                return p.Add(shapeMargin, 0);
            }

            if (alignment == ConnectorOrientation.Bottom)
            {
                return p.Add(0, shapeMargin);
            }

            if (alignment == ConnectorOrientation.Left)
            {
                return p.Add(-shapeMargin, 0);
            }

            throw new NotImplementedException();
        }

        /// <summary>
        /// Tells whether a port on the given side leaves its shape along the vertical axis.
        /// </summary>
        /// <param name="alignment">The side the port faces.</param>
        /// <returns><c>true</c> for Top and Bottom; <c>false</c> for every other orientation.</returns>
        private static bool IsVerticalSide(ConnectorOrientation alignment)
        {
            switch (alignment)
            {
                case ConnectorOrientation.Top:
                case ConnectorOrientation.Bottom:
                    return true;

                // Other orientations are not handled yet and count as non-vertical.
                default:
                    return false;
            }
        }

        /// <summary>
        /// Slices <paramref name="bounds"/> into a grid of rectangles along the given rulers.
        /// </summary>
        /// <remarks>
        /// Both ruler lists are sorted in place. N vertical and M horizontal rulers produce
        /// (N + 1) x (M + 1) cells, stored row by row from the top-left of the bounds.
        /// </remarks>
        /// <param name="verticals">X coordinates of the vertical rulers.</param>
        /// <param name="horizontals">Y coordinates of the horizontal rulers.</param>
        /// <param name="bounds">Outer rectangle that closes the first and last row and column.</param>
        /// <returns>The populated grid.</returns>
        private static Grid RulersToGrid(List<double> verticals, List<double> horizontals, RectangleBase bounds)
        {
            verticals.Sort();
            horizontals.Sort();

            var grid = new Grid();
            var top = bounds.Top;

            // One row per horizontal ruler plus a final row that ends at the bottom of the bounds.
            for (var rowIndex = 0; rowIndex <= horizontals.Count; rowIndex++)
            {
                var bottom = rowIndex < horizontals.Count ? horizontals[rowIndex] : bounds.Bottom;
                var left = bounds.Left;

                // One column per vertical ruler plus a final column that ends at the right of the bounds.
                for (var columnIndex = 0; columnIndex <= verticals.Count; columnIndex++)
                {
                    var right = columnIndex < verticals.Count ? verticals[columnIndex] : bounds.Right;
                    grid.Set(rowIndex, columnIndex, new RectangleBase(left, top, right, bottom, true));
                    left = right;
                }

                top = bottom;
            }

            return grid;
        }

        /// <summary>
        /// Collects candidate routing points from the cells of <paramref name="grid"/>.
        /// </summary>
        /// <remarks>
        /// Corner cells contribute their four corners, cells on an outer edge contribute the three
        /// points along that edge, and inner cells contribute corners, edge midpoints and center.
        /// The collected points are passed through <c>ReducePoints</c> and any point contained in an
        /// obstacle is removed.
        /// </remarks>
        /// <param name="grid">The grid produced from the rulers.</param>
        /// <param name="obstacles">Rectangles that routing points must not fall inside.</param>
        /// <returns>The remaining candidate points.</returns>
        private static List<PointBase> GridToSpots(Grid grid, RectangleBase[] obstacles)
        {
            var candidates = new List<PointBase>();
            var lastRowIndex = grid.Rows - 1;
            var lastColumnIndex = grid.Columns - 1;

            foreach (var rowEntry in grid.Data)
            {
                var onTopEdge = rowEntry.Key == 0;
                var onBottomEdge = rowEntry.Key == lastRowIndex;

                foreach (var cellEntry in rowEntry.Value)
                {
                    var cell = cellEntry.Value;
                    var onLeftEdge = cellEntry.Key == 0;
                    var onRightEdge = cellEntry.Key == lastColumnIndex;

                    if ((onTopEdge || onBottomEdge) && (onLeftEdge || onRightEdge))
                    {
                        // Corner cell of the grid.
                        candidates.AddRange(new PointBase[]
                        {
                            cell.NorthWest, cell.NorthEast, cell.SouthWest, cell.SouthEast
                        });
                    }
                    else if (onTopEdge)
                    {
                        candidates.AddRange(new PointBase[] { cell.NorthWest, cell.North, cell.NorthEast });
                    }
                    else if (onBottomEdge)
                    {
                        candidates.AddRange(new PointBase[] { cell.SouthEast, cell.South, cell.SouthWest });
                    }
                    else if (onLeftEdge)
                    {
                        candidates.AddRange(new PointBase[] { cell.NorthWest, cell.West, cell.SouthWest });
                    }
                    else if (onRightEdge)
                    {
                        candidates.AddRange(new PointBase[] { cell.NorthEast, cell.East, cell.SouthEast });
                    }
                    else
                    {
                        // Inner cell: all eight border points and the center.
                        candidates.AddRange(new PointBase[]
                        {
                            cell.NorthWest, cell.North, cell.NorthEast, cell.East,
                            cell.SouthEast, cell.South, cell.SouthWest, cell.West,
                            cell.Center
                        });
                    }
                }
            }

            bool TouchesObstacle(PointBase point) => obstacles.Any(obstacle => obstacle.ContainsPoint(point));

            // Drop repeated points, then the ones lying inside a shape.
            return ReducePoints(candidates).Where(point => !TouchesObstacle(point)).ToList();
        }

        /// <summary>
        /// Removes repeated points from <paramref name="points"/>.
        /// </summary>
        /// <remarks>
        /// Points are bucketed by Y coordinate and each bucket keeps every X coordinate once. The
        /// result is produced lazily, bucket by bucket.
        /// </remarks>
        /// <param name="points">Points that may contain duplicates.</param>
        /// <returns>One point per distinct (X, Y) pair.</returns>
        private static IEnumerable<PointBase> ReducePoints(List<PointBase> points)
        {
            var xsByY = new Dictionary<double, List<double>>();

            foreach (var point in points)
            {
                var (px, py) = point;

                if (!xsByY.TryGetValue(py, out var bucket))
                {
                    bucket = new List<double>();
                    xsByY.Add(py, bucket);
                }

                if (!bucket.Contains(px))
                {
                    bucket.Add(px);
                }
            }

            foreach (var entry in xsByY)
            {
                foreach (var px in entry.Value)
                {
                    yield return new PointBase(px, entry.Key);
                }
            }
        }

        /// <summary>
        /// Builds a graph whose nodes are the given spots and whose edges join each spot to the spot
        /// on the previous distinct X of the same row and the previous distinct Y of the same column,
        /// when such a spot exists.
        /// </summary>
        /// <param name="spots">All candidate routing points.</param>
        /// <returns>The connected graph; every edge is added in both directions.</returns>
        private static PointGraph CreateGraph(List<PointBase> spots)
        {
            var graph = new PointGraph();

            // Hash sets give constant-time de-duplication of the coordinates
            // (previously a linear List.Contains per spot).
            var distinctXs = new HashSet<double>();
            var distinctYs = new HashSet<double>();

            foreach (var spot in spots)
            {
                var (x, y) = spot;
                distinctXs.Add(x);
                distinctYs.Add(y);
                graph.Add(spot);
            }

            var hotXs = new List<double>(distinctXs);
            var hotYs = new List<double>(distinctYs);
            hotXs.Sort();
            hotYs.Sort();

            // Joins two points in both directions, provided the neighbour is a node of the graph.
            void LinkIfPresent(PointBase neighbour, PointBase current)
            {
                if (!graph.Has(neighbour))
                {
                    return;
                }

                graph.Connect(neighbour, current);
                graph.Connect(current, neighbour);
            }

            for (var row = 0; row < hotYs.Count; row++)
            {
                for (var column = 0; column < hotXs.Count; column++)
                {
                    var current = new PointBase(hotXs[column], hotYs[row]);
                    if (!graph.Has(current))
                    {
                        continue;
                    }

                    // Neighbour on the previous X first, then the one on the previous Y; the order
                    // of edge insertion is kept stable on purpose.
                    if (column > 0)
                    {
                        LinkIfPresent(new PointBase(hotXs[column - 1], hotYs[row]), current);
                    }

                    if (row > 0)
                    {
                        LinkIfPresent(new PointBase(hotXs[column], hotYs[row - 1]), current);
                    }
                }
            }

            return graph;
        }

        /// <summary>
        /// Pushes a connection point outwards from its shape.
        /// </summary>
        /// <param name="p">The connection point.</param>
        /// <param name="margin">Distance to move the point.</param>
        /// <param name="alignment">The side the point sits on; decides the direction of the move.</param>
        /// <returns>The moved point.</returns>
        /// <exception cref="NotImplementedException">
        /// Thrown for any orientation other than Top, Right, Bottom or Left.
        /// </exception>
        private static PointBase ExtrudeCp(PointBase p, double margin, ConnectorOrientation alignment)
        {
            PointBase extruded;

            switch (alignment)
            {
                case ConnectorOrientation.Top:
                    extruded = p.Add(0, -margin);
                    break;

                case ConnectorOrientation.Right:
                    extruded = p.Add(margin, 0);
                    break;

                case ConnectorOrientation.Bottom:
                    extruded = p.Add(0, margin);
                    break;

                case ConnectorOrientation.Left:
                    extruded = p.Add(-margin, 0);
                    break;

                default:
                    throw new NotImplementedException();
            }

            return extruded;
        }

        /// <summary>
        /// Runs the graph's shortest-path calculation from <paramref name="origin"/> and returns the
        /// points recorded on the destination node's path.
        /// </summary>
        /// <param name="graph">Graph containing both points.</param>
        /// <param name="origin">Start point.</param>
        /// <param name="destination">End point.</param>
        /// <returns>
        /// The data of every node in the destination node's <c>ShortestPath</c> list, in order.
        /// </returns>
        /// <exception cref="Exception">Thrown when either point is not a node of the graph.</exception>
        private static PointBase[] ShortestPath(PointGraph graph, PointBase origin, PointBase destination)
        {
            var startNode = graph.Get(origin);
            var goalNode = graph.Get(destination);

            var bothFound = startNode != null && goalNode != null;
            if (!bothFound)
            {
                throw new Exception("Origin node or Destination node not found");
            }

            graph.CalculateShortestPathFromSource(graph, startNode);

            var trail = goalNode.ShortestPath;
            var route = new PointBase[trail.Count];
            for (var i = 0; i < route.Length; i++)
            {
                route[i] = trail[i].Data;
            }

            return route;
        }

        /// <summary>
        /// Removes interior points that do not change the direction of the path.
        /// </summary>
        /// <param name="points">The path to simplify.</param>
        /// <returns>
        /// The same array when it has two points or fewer; otherwise a new array holding the first
        /// point, every interior point for which <c>GetBend</c> does not report "none", and the last
        /// point.
        /// </returns>
        private static PointBase[] SimplifyPath(PointBase[] points)
        {
            if (points.Length <= 2)
            {
                return points;
            }

            var lastIndex = points.Length - 1;
            var kept = new List<PointBase> { points[0] };

            // Only interior points can be dropped; the end points are always kept.
            for (var i = 1; i < lastIndex; i++)
            {
                var bend = GetBend(points[i - 1], points[i], points[i + 1]);
                if (bend == "none")
                {
                    continue;
                }

                kept.Add(points[i]);
            }

            kept.Add(points[lastIndex]);
            return kept.ToArray();
        }

        /// <summary>
        /// Classifies the turn made at <paramref name="b"/> when travelling from
        /// <paramref name="a"/> through <paramref name="b"/> to <paramref name="c"/>.
        /// </summary>
        /// <param name="a">Point before the candidate bend.</param>
        /// <param name="b">The candidate bend.</param>
        /// <param name="c">Point after the candidate bend.</param>
        /// <returns>
        /// "none" when all three points share an X or a Y; "unknown" when either segment is neither
        /// horizontal nor vertical; "s" or "n" for a horizontal segment followed by a vertical one
        /// (by whether <c>c.Y</c> is greater than <c>b.Y</c>); "e" or "w" for a vertical segment
        /// followed by a horizontal one (by whether <c>c.X</c> is greater than <c>b.X</c>).
        /// </returns>
        /// <exception cref="Exception">Thrown when none of the cases above matches.</exception>
        private static string GetBend(PointBase a, PointBase b, PointBase c)
        {
            var firstVertical = a.X == b.X;
            var firstHorizontal = a.Y == b.Y;
            var secondVertical = b.X == c.X;
            var secondHorizontal = b.Y == c.Y;

            // All three points on one vertical or one horizontal line: no turn.
            if ((firstVertical && secondVertical) || (firstHorizontal && secondHorizontal))
            {
                return "none";
            }

            // A diagonal segment on either side cannot be classified.
            if (!firstVertical && !firstHorizontal)
            {
                return "unknown";
            }

            if (!secondVertical && !secondHorizontal)
            {
                return "unknown";
            }

            if (firstHorizontal && secondVertical)
            {
                return c.Y > b.Y ? "s" : "n";
            }

            if (firstVertical && secondHorizontal)
            {
                return c.X > b.X ? "e" : "w";
            }

            throw new Exception("Nope");
        }

        /// <summary>
        /// Sparse table of rectangles addressed by row and column.
        /// </summary>
        class Grid
        {
            /// <summary>Cells by row key, then by column key.</summary>
            public Dictionary<double, Dictionary<double, RectangleBase>> Data { get; }
                = new Dictionary<double, Dictionary<double, RectangleBase>>();

            /// <summary>Largest row passed to <see cref="Set"/> plus one (0 while empty).</summary>
            public double Rows { get; private set; }

            /// <summary>Largest column passed to <see cref="Set"/> plus one (0 while empty).</summary>
            public double Columns { get; private set; }

            /// <summary>
            /// Stores <paramref name="rectangle"/> at the given position and grows
            /// <see cref="Rows"/> and <see cref="Columns"/> to cover it.
            /// </summary>
            /// <exception cref="ArgumentException">The position is already occupied.</exception>
            public void Set(double row, double column, RectangleBase rectangle)
            {
                Rows = Math.Max(Rows, row + 1);
                Columns = Math.Max(Columns, column + 1);

                if (!Data.TryGetValue(row, out var cells))
                {
                    cells = new Dictionary<double, RectangleBase>();
                    Data.Add(row, cells);
                }

                cells.Add(column, rectangle);
            }

            /// <summary>
            /// Returns the rectangle stored at the given position, or
            /// <c>RectangleBase.Empty</c> when nothing is stored there.
            /// </summary>
            public RectangleBase Get(double row, double column)
            {
                if (Data.TryGetValue(row, out var cells) && cells.TryGetValue(column, out var rectangle))
                {
                    return rectangle;
                }

                return RectangleBase.Empty;
            }

            /// <summary>Returns every stored rectangle, row after row.</summary>
            public RectangleBase[] Rectangles() => Data.Values.SelectMany(cells => cells.Values).ToArray();
        }

        /// <summary>
        /// Graph of <see cref="PointNode"/>s with a Dijkstra-style shortest-path calculation that
        /// adds an extra cost whenever the path changes direction.
        /// </summary>
        class PointGraph
        {
            // Nodes keyed by the ToInvariantString() form of X, then of Y.
            public readonly Dictionary<string, Dictionary<string, PointNode>> _index = new Dictionary<string, Dictionary<string, PointNode>>();

            /// <summary>Adds a node for <paramref name="p"/> unless one already exists at its coordinates.</summary>
            public void Add(PointBase p)
            {
                var (x, y) = p;
                var xKey = x.ToInvariantString();
                var yKey = y.ToInvariantString();

                if (!_index.TryGetValue(xKey, out var column))
                {
                    column = new Dictionary<string, PointNode>();
                    _index.Add(xKey, column);
                }

                if (!column.ContainsKey(yKey))
                {
                    column.Add(yKey, new PointNode(p));
                }
            }

            /// <summary>Tells whether a node is registered at the coordinates of <paramref name="p"/>.</summary>
            public bool Has(PointBase p) => TryFind(p, out _);

            /// <summary>Returns the node at the coordinates of <paramref name="p"/>, or <c>null</c> when there is none.</summary>
            public PointNode Get(PointBase p) => TryFind(p, out var node) ? node : null;

            /// <summary>
            /// Adds a one-way edge from <paramref name="a"/> to <paramref name="b"/>, weighted by the
            /// distance between them. Does nothing when either point is not a node of the graph.
            /// </summary>
            /// <exception cref="ArgumentException">The edge already exists.</exception>
            public void Connect(PointBase a, PointBase b)
            {
                var from = Get(a);
                var to = Get(b);

                if (from == null || to == null)
                {
                    return;
                }

                from.AdjacentNodes.Add(to, a.DistanceTo(b));
            }

            /// <summary>
            /// Fills in <c>Distance</c> and <c>ShortestPath</c> for every node reachable from
            /// <paramref name="source"/>.
            /// </summary>
            /// <param name="graph">Returned unchanged; not otherwise used.</param>
            /// <param name="source">Node the distances are measured from; its distance is set to 0.</param>
            /// <returns><paramref name="graph"/>.</returns>
            public PointGraph CalculateShortestPathFromSource(PointGraph graph, PointNode source)
            {
                source.Distance = 0;

                var settled = new HashSet<PointNode>();
                var unsettled = new HashSet<PointNode> { source };

                while (unsettled.Count > 0)
                {
                    var current = GetLowestDistanceNode(unsettled);
                    unsettled.Remove(current);

                    foreach (var edge in current.AdjacentNodes)
                    {
                        var neighbour = edge.Key;
                        if (settled.Contains(neighbour))
                        {
                            continue;
                        }

                        CalculateMinimumDistance(neighbour, edge.Value, current);
                        unsettled.Add(neighbour);
                    }

                    settled.Add(current);
                }

                return graph;
            }

            // Single lookup shared by Has and Get: formats the coordinates once and probes both
            // dictionary levels with TryGetValue.
            private bool TryFind(PointBase p, out PointNode node)
            {
                var (x, y) = p;
                var xKey = x.ToInvariantString();
                var yKey = y.ToInvariantString();

                if (_index.TryGetValue(xKey, out var column) && column.TryGetValue(yKey, out node))
                {
                    return true;
                }

                node = null;
                return false;
            }

            // Returns the node with the smallest distance; the first one encountered wins ties.
            // Yields null when no node has a distance below double.MaxValue.
            private PointNode GetLowestDistanceNode(HashSet<PointNode> unsettledNodes)
            {
                PointNode closest = null;
                var closestDistance = double.MaxValue;

                foreach (var candidate in unsettledNodes)
                {
                    if (candidate.Distance < closestDistance)
                    {
                        closestDistance = candidate.Distance;
                        closest = candidate;
                    }
                }

                return closest;
            }

            // Relaxes the edge sourceNode -> evaluationNode. A change of direction relative to the
            // segment that reached sourceNode costs an extra (edgeWeight + 1)^2.
            private void CalculateMinimumDistance(PointNode evaluationNode, double edgeWeight, PointNode sourceNode)
            {
                var arrivedAlong = InferPathDirection(sourceNode);
                var leavingAlong = DirectionOfNodes(sourceNode, evaluationNode);
                var turns = arrivedAlong != null && leavingAlong != null && arrivedAlong != leavingAlong;
                var turnPenalty = turns ? Math.Pow(edgeWeight + 1, 2) : 0;
                var candidateDistance = sourceNode.Distance + edgeWeight + turnPenalty;

                if (candidateDistance < evaluationNode.Distance)
                {
                    evaluationNode.Distance = candidateDistance;

                    // The recorded path is the path to the source followed by the source itself;
                    // the evaluated node is not part of its own path.
                    evaluationNode.ShortestPath = new List<PointNode>(sourceNode.ShortestPath) { sourceNode };
                }
            }

            // Axis of the segment a-b: 'v' when the points share X, 'h' when they share Y, null
            // otherwise. Equal X is tested first. The result is only ever compared with another
            // result of this method.
            private char? DirectionOf(PointBase a, PointBase b)
            {
                if (a.X == b.X)
                {
                    return 'v';
                }

                if (a.Y == b.Y)
                {
                    return 'h';
                }

                return null;
            }

            private char? DirectionOfNodes(PointNode a, PointNode b) => DirectionOf(a.Data, b.Data);

            // Axis of the last segment on the node's recorded path, or null when the path is empty.
            private char? InferPathDirection(PointNode node)
            {
                var trail = node.ShortestPath;
                if (trail.Count == 0)
                {
                    return null;
                }

                return DirectionOfNodes(trail[trail.Count - 1], node);
            }
        }

        /// <summary>
        /// Graph node wrapping a point, together with the bookkeeping used by the shortest-path
        /// calculation.
        /// </summary>
        class PointNode
        {
            public PointNode(PointBase data) => Data = data;

            /// <summary>The point this node stands for.</summary>
            public PointBase Data { get; }

            /// <summary>Best known distance to this node; <see cref="double.MaxValue"/> until one is set.</summary>
            public double Distance { get; set; } = double.MaxValue;

            /// <summary>Nodes recorded as the path leading to this node; empty by default.</summary>
            public List<PointNode> ShortestPath { get; set; } = new List<PointNode>();

            /// <summary>Neighbouring nodes and the weight of the edge to each.</summary>
            public Dictionary<PointNode, double> AdjacentNodes { get; set; } = new Dictionary<PointNode, double>();
        }
    }


}