using System;
using System.Collections.Generic;
using System.Data;
using System.Linq;
using AIStudio.Wpf.DiagramDesigner.Geometrys;

// Implementation taken from the JS version: https://gist.github.com/menendezpoo/4a8894c152383b9d7a870c24a04447e4
// Todo: Make it more c#, Benchmark A* vs Dijkstra, Add more options
namespace AIStudio.Wpf.DiagramDesigner
{
    public static partial class Routers
    {
        /// <summary>
        /// Routes <paramref name="link"/> with axis-aligned segments that go around the bounds of the
        /// source and sink items.
        /// </summary>
        /// <remarks>
        /// Outline: build vertical/horizontal "rulers" from the padded item edges and the port positions,
        /// cut the surrounding area into a grid of cells, collect candidate points on those cells that are not
        /// contained in either padded item, connect horizontally/vertically adjacent candidates into a graph and
        /// run a turn-penalised Dijkstra search between the two port "antennas" (points pushed out from each
        /// port along its side by the shape margin).
        /// </remarks>
        /// <param name="_">The diagram. Only forwarded to <c>Normal</c> when this router falls back.</param>
        /// <param name="link">The connection to route.</param>
        /// <returns>
        /// The simplified route vertices from the source antenna to the sink antenna, or the result of
        /// <c>Normal</c> when the link is not fully connected or no orthogonal route was found.
        /// </returns>
        /// <exception cref="NotSupportedException"><paramref name="link"/> is portless.</exception>
        public static PointBase[] Orthogonal(IDiagramViewModel _, ConnectionViewModel link)
        {
            if (link.IsPortless)
                throw new NotSupportedException("Orthogonal router doesn't work with portless links yet");

            if (link.IsFullConnection == false)
                return Normal(_, link);

            var shapeMargin = 10;
            var globalBoundsMargin = 50;
            var spots = new List<PointBase>();
            var verticals = new List<double>();
            var horizontals = new List<double>();

            var sideA = link.SourceConnectorInfo.Orientation;
            var sideAVertical = IsVerticalSide(sideA);
            var sideB = link.SinkConnectorInfo.Orientation;
            var sideBVertical = IsVerticalSide(sideB);

            var originA = GetPortPositionBasedOnAlignment(link.SourceConnectorInfo);
            var originB = GetPortPositionBasedOnAlignment(link.SinkConnectorInfo);
            var shapeA = link.SourceConnectorInfo.DataItem.GetBounds(includePorts: true);
            var shapeB = link.SinkConnectorInfoFully.DataItem.GetBounds(includePorts: true);
            var inflatedA = shapeA.InflateRectangle(shapeMargin, shapeMargin);
            var inflatedB = shapeB.InflateRectangle(shapeMargin, shapeMargin);

            // Padded shapes that touch leave no corridor between them: route tight against the raw bounds instead.
            if (inflatedA.Intersects(inflatedB))
            {
                shapeMargin = 0;
                inflatedA = shapeA;
                inflatedB = shapeB;
            }

            // Curated bounds to stick to
            var bounds = inflatedA.UnionRectangle(inflatedB).InflateRectangle(globalBoundsMargin, globalBoundsMargin);

            // Add edges to rulers
            verticals.Add(inflatedA.Left);
            verticals.Add(inflatedA.Right);
            horizontals.Add(inflatedA.Top);
            horizontals.Add(inflatedA.Bottom);
            verticals.Add(inflatedB.Left);
            verticals.Add(inflatedB.Right);
            horizontals.Add(inflatedB.Top);
            horizontals.Add(inflatedB.Bottom);

            // Rulers at origins of shapes: a port on a top/bottom side contributes its X, otherwise its Y.
            (sideAVertical ? verticals : horizontals).Add(sideAVertical ? originA.X : originA.Y);
            (sideBVertical ? verticals : horizontals).Add(sideBVertical ? originB.X : originB.Y);

            // Origin and destination by extruding antennas. They are also graph spots, because grid points that
            // touch the padded shapes are filtered out and the antennas would otherwise be missing from the graph.
            var origin = ExtrudeCp(originA, shapeMargin, sideA);
            var destination = ExtrudeCp(originB, shapeMargin, sideB);
            spots.Add(origin);
            spots.Add(destination);

            // Sort rulers
            verticals.Sort();
            horizontals.Sort();

            // Create grid
            var grid = RulersToGrid(verticals, horizontals, bounds);
            var gridPoints = GridToSpots(grid, new[] { inflatedA, inflatedB });

            // Add to spots
            spots.AddRange(gridPoints);

            // Create graph
            var graph = CreateGraph(spots);

            var path = ShortestPath(graph, origin, destination);
            if (path.Length > 0)
            {
                // ShortestPath yields the nodes *before* the destination (it starts at the origin antenna but
                // never contains the destination itself). Append the sink antenna so the route is symmetric and
                // the final hop into the sink port stays orthogonal; SimplifyPath removes it again if redundant.
                var fullPath = new PointBase[path.Length + 1];
                Array.Copy(path, fullPath, path.Length);
                fullPath[path.Length] = destination;
                return SimplifyPath(fullPath);
            }

            return Normal(_, link);
        }

        /// <summary>
        /// Returns the "antenna" point of a port: <paramref name="p"/> pushed outwards by
        /// <paramref name="shapeMargin"/> in the direction of <paramref name="alignment"/>.
        /// Kept for compatibility; identical to <see cref="ExtrudeCp"/>.
        /// </summary>
        /// <exception cref="NotImplementedException">The orientation is not Top/Right/Bottom/Left.</exception>
        private static PointBase GetOriginSpot(PointBase p, ConnectorOrientation alignment, double shapeMargin)
            => ExtrudeCp(p, shapeMargin, alignment);

        /// <summary>True when the port sits on a horizontal edge (Top/Bottom), i.e. it leaves the shape vertically.</summary>
        private static bool IsVerticalSide(ConnectorOrientation alignment)
            => alignment == ConnectorOrientation.Top || alignment == ConnectorOrientation.Bottom; // Add others

        /// <summary>
        /// Cuts <paramref name="bounds"/> into cells using the given rulers. Row <c>r</c>, column <c>c</c> spans from the
        /// previous ruler (or the bounds edge) to the current ruler (or the opposite bounds edge), so a grid with
        /// <c>V</c> verticals and <c>H</c> horizontals has <c>(H + 1)</c> rows of <c>(V + 1)</c> cells.
        /// </summary>
        /// <remarks>Both ruler lists are sorted in place.</remarks>
        private static Grid RulersToGrid(List<double> verticals, List<double> horizontals, RectangleBase bounds)
        {
            var result = new Grid();

            // The cell construction below walks the rulers in ascending order.
            verticals.Sort();
            horizontals.Sort();

            var lastX = bounds.Left;
            var lastY = bounds.Top;
            var column = 0;
            var row = 0;

            foreach (var y in horizontals)
            {
                foreach (var x in verticals)
                {
                    result.Set(row, column++, new RectangleBase(lastX, lastY, x, y, true));
                    lastX = x;
                }

                // Last cell of the row
                result.Set(row, column, new RectangleBase(lastX, lastY, bounds.Right, y, true));
                lastX = bounds.Left;
                lastY = y;
                column = 0;
                row++;
            }

            lastX = bounds.Left;

            // Last row of cells
            foreach (var x in verticals)
            {
                result.Set(row, column++, new RectangleBase(lastX, lastY, x, bounds.Bottom, true));
                lastX = x;
            }

            // Last cell of last row
            result.Set(row, column, new RectangleBase(lastX, lastY, bounds.Right, bounds.Bottom, true));

            return result;
        }

        /// <summary>
        /// Collects candidate routing points from the grid cells: the four corners of the four corner cells, the
        /// outer-edge points of border cells, and all nine anchor points of interior cells. Duplicates are removed
        /// and any point contained in one of <paramref name="obstacles"/> is discarded.
        /// </summary>
        private static List<PointBase> GridToSpots(Grid grid, RectangleBase[] obstacles)
        {
            bool obstacleCollision(PointBase p) => obstacles.Any(o => o.ContainsPoint(p));

            var gridPoints = new List<PointBase>();
            foreach (var rowEntry in grid.Data)
            {
                var row = rowEntry.Key;
                var firstRow = row == 0;
                var lastRow = row == grid.Rows - 1;

                foreach (var cellEntry in rowEntry.Value)
                {
                    var col = cellEntry.Key;
                    var r = cellEntry.Value;
                    var firstCol = col == 0;
                    var lastCol = col == grid.Columns - 1;
                    var nw = firstCol && firstRow;
                    var ne = firstRow && lastCol;
                    var se = lastRow && lastCol;
                    var sw = lastRow && firstCol;

                    if (nw || ne || se || sw)
                    {
                        gridPoints.Add(r.NorthWest);
                        gridPoints.Add(r.NorthEast);
                        gridPoints.Add(r.SouthWest);
                        gridPoints.Add(r.SouthEast);
                    }
                    else if (firstRow)
                    {
                        gridPoints.Add(r.NorthWest);
                        gridPoints.Add(r.North);
                        gridPoints.Add(r.NorthEast);
                    }
                    else if (lastRow)
                    {
                        gridPoints.Add(r.SouthEast);
                        gridPoints.Add(r.South);
                        gridPoints.Add(r.SouthWest);
                    }
                    else if (firstCol)
                    {
                        gridPoints.Add(r.NorthWest);
                        gridPoints.Add(r.West);
                        gridPoints.Add(r.SouthWest);
                    }
                    else if (lastCol)
                    {
                        gridPoints.Add(r.NorthEast);
                        gridPoints.Add(r.East);
                        gridPoints.Add(r.SouthEast);
                    }
                    else
                    {
                        gridPoints.Add(r.NorthWest);
                        gridPoints.Add(r.North);
                        gridPoints.Add(r.NorthEast);
                        gridPoints.Add(r.East);
                        gridPoints.Add(r.SouthEast);
                        gridPoints.Add(r.South);
                        gridPoints.Add(r.SouthWest);
                        gridPoints.Add(r.West);
                        gridPoints.Add(r.Center);
                    }
                }
            }

            // Reduce repeated points and filter out those who touch shapes
            return ReducePoints(gridPoints).Where(p => !obstacleCollision(p)).ToList();
        }

        /// <summary>
        /// Removes duplicate points (exact coordinate equality). Output is grouped by Y in first-seen order, and
        /// within each Y the X values keep their first-seen order.
        /// </summary>
        private static IEnumerable<PointBase> ReducePoints(List<PointBase> points)
        {
            var xsByY = new Dictionary<double, List<double>>();
            var seenXsByY = new Dictionary<double, HashSet<double>>();

            foreach (var p in points)
            {
                (var x, var y) = p;
                if (!seenXsByY.TryGetValue(y, out var seenXs))
                {
                    seenXs = new HashSet<double>();
                    seenXsByY.Add(y, seenXs);
                    xsByY.Add(y, new List<double>());
                }

                // HashSet.Add returns false for an X already recorded on this row.
                if (seenXs.Add(x))
                    xsByY[y].Add(x);
            }

            foreach (var entry in xsByY)
            {
                var y = entry.Key;
                foreach (var x in entry.Value)
                {
                    yield return new PointBase(x, y);
                }
            }
        }

        /// <summary>
        /// Builds an undirected graph over <paramref name="spots"/>: each spot is linked to its nearest neighbour
        /// (in the sorted set of all spot coordinates) on the same row and on the same column, when that neighbour
        /// is itself a spot. Edge weights are Euclidean distances.
        /// </summary>
        private static PointGraph CreateGraph(List<PointBase> spots)
        {
            var uniqueXs = new HashSet<double>();
            var uniqueYs = new HashSet<double>();
            var graph = new PointGraph();

            foreach (var p in spots)
            {
                (var x, var y) = p;
                uniqueXs.Add(x);
                uniqueYs.Add(y);
                graph.Add(p);
            }

            var hotXs = uniqueXs.ToList();
            var hotYs = uniqueYs.ToList();
            hotXs.Sort();
            hotYs.Sort();

            for (var i = 0; i < hotYs.Count; i++)
            {
                for (var j = 0; j < hotXs.Count; j++)
                {
                    var b = new PointBase(hotXs[j], hotYs[i]);
                    if (!graph.Has(b)) continue;

                    // Link to the previous X on the same row.
                    if (j > 0)
                    {
                        var a = new PointBase(hotXs[j - 1], hotYs[i]);
                        if (graph.Has(a))
                        {
                            graph.Connect(a, b);
                            graph.Connect(b, a);
                        }
                    }

                    // Link to the previous Y on the same column.
                    if (i > 0)
                    {
                        var a = new PointBase(hotXs[j], hotYs[i - 1]);
                        if (graph.Has(a))
                        {
                            graph.Connect(a, b);
                            graph.Connect(b, a);
                        }
                    }
                }
            }

            return graph;
        }

        /// <summary>
        /// Pushes <paramref name="p"/> outwards by <paramref name="margin"/> in the direction the port faces
        /// (Top: -Y, Right: +X, Bottom: +Y, Left: -X).
        /// </summary>
        /// <exception cref="NotImplementedException">The orientation is not Top/Right/Bottom/Left.</exception>
        private static PointBase ExtrudeCp(PointBase p, double margin, ConnectorOrientation alignment)
        {
            switch (alignment)
            {
                case ConnectorOrientation.Top: return p.Add(0, -margin);
                case ConnectorOrientation.Right: return p.Add(margin, 0);
                case ConnectorOrientation.Bottom: return p.Add(0, margin);
                case ConnectorOrientation.Left: return p.Add(-margin, 0);
                default: throw new NotImplementedException();
            }
        }

        /// <summary>
        /// Runs the search from <paramref name="origin"/> and returns the points of the best path found to
        /// <paramref name="destination"/>, starting with <paramref name="origin"/> and EXCLUDING
        /// <paramref name="destination"/>. Empty when the destination is unreachable or equals the origin.
        /// </summary>
        /// <exception cref="InvalidOperationException">Either point is not a node of <paramref name="graph"/>.</exception>
        private static PointBase[] ShortestPath(PointGraph graph, PointBase origin, PointBase destination)
        {
            var originNode = graph.Get(origin);
            var destinationNode = graph.Get(destination);

            if (originNode == null || destinationNode == null)
                throw new InvalidOperationException("Origin node or Destination node not found");

            graph.CalculateShortestPathFromSource(graph, originNode);
            return destinationNode.ShortestPath.Select(n => n.Data).ToArray();
        }

        /// <summary>
        /// Drops interior points that lie on a straight horizontal or vertical run (no bend).
        /// The first and last points are always kept.
        /// </summary>
        private static PointBase[] SimplifyPath(PointBase[] points)
        {
            if (points.Length <= 2)
            {
                return points;
            }

            var r = new List<PointBase>() { points[0] };
            for (var i = 1; i < points.Length; i++)
            {
                var cur = points[i];

                if (i == (points.Length - 1))
                {
                    r.Add(cur);
                    break;
                }

                var prev = points[i - 1];
                var next = points[i + 1];
                var bend = GetBend(prev, cur, next);

                if (bend != "none")
                {
                    r.Add(cur);
                }
            }

            return r.ToArray();
        }

        /// <summary>
        /// Classifies the turn at <paramref name="b"/> on the polyline a-b-c:
        /// "none" when the three points are collinear on one axis, "unknown" when a segment is diagonal,
        /// otherwise the compass direction of the second segment ("n"/"s" after a horizontal segment,
        /// "e"/"w" after a vertical one).
        /// </summary>
        private static string GetBend(PointBase a, PointBase b, PointBase c)
        {
            var equalX = a.X == b.X && b.X == c.X;
            var equalY = a.Y == b.Y && b.Y == c.Y;
            var segment1Horizontal = a.Y == b.Y;
            var segment1Vertical = a.X == b.X;
            var segment2Horizontal = b.Y == c.Y;
            var segment2Vertical = b.X == c.X;

            if (equalX || equalY)
            {
                return "none";
            }

            if (!(segment1Vertical || segment1Horizontal) || !(segment2Vertical || segment2Horizontal))
            {
                return "unknown";
            }

            if (segment1Horizontal && segment2Vertical)
            {
                return c.Y > b.Y ? "s" : "n";
            }
            else if (segment1Vertical && segment2Horizontal)
            {
                return c.X > b.X ? "e" : "w";
            }

            // Unreachable for the combinations handled above; kept as a safety net.
            throw new InvalidOperationException("Nope");
        }

        /// <summary>Sparse row/column store of grid cells produced by <see cref="RulersToGrid"/>.</summary>
        class Grid
        {
            public Grid()
            {
                Data = new Dictionary<double, Dictionary<double, RectangleBase>>();
            }

            /// <summary>Cells keyed by row, then by column.</summary>
            public Dictionary<double, Dictionary<double, RectangleBase>> Data { get; }

            /// <summary>One more than the highest row index ever set.</summary>
            public double Rows { get; private set; }

            /// <summary>One more than the highest column index ever set.</summary>
            public double Columns { get; private set; }

            /// <summary>Stores a cell. Throws if the same (row, column) is set twice.</summary>
            public void Set(double row, double column, RectangleBase rectangle)
            {
                Rows = Math.Max(Rows, row + 1);
                Columns = Math.Max(Columns, column + 1);

                if (!Data.TryGetValue(row, out var columns))
                {
                    columns = new Dictionary<double, RectangleBase>();
                    Data.Add(row, columns);
                }

                columns.Add(column, rectangle);
            }

            /// <summary>Returns the cell at (row, column), or <c>RectangleBase.Empty</c> when none was set.</summary>
            public RectangleBase Get(double row, double column)
            {
                if (Data.TryGetValue(row, out var columns) && columns.TryGetValue(column, out var rectangle))
                    return rectangle;

                return RectangleBase.Empty;
            }

            public RectangleBase[] Rectangles() => Data.SelectMany(r => r.Value.Values).ToArray();
        }

        /// <summary>
        /// Graph of route candidate points, indexed by the invariant-culture text of their X then Y coordinates,
        /// with a turn-penalised Dijkstra search.
        /// </summary>
        class PointGraph
        {
            public readonly Dictionary<string, Dictionary<string, PointNode>> _index = new Dictionary<string, Dictionary<string, PointNode>>();

            /// <summary>Adds a node for <paramref name="p"/> unless one already exists at those coordinates.</summary>
            public void Add(PointBase p)
            {
                (var x, var y) = p;
                var xs = x.ToInvariantString();
                var ys = y.ToInvariantString();

                if (!_index.TryGetValue(xs, out var column))
                {
                    column = new Dictionary<string, PointNode>();
                    _index.Add(xs, column);
                }

                if (!column.ContainsKey(ys))
                    column.Add(ys, new PointNode(p));
            }

            private PointNode GetLowestDistanceNode(HashSet<PointNode> unsettledNodes)
            {
                PointNode lowestDistanceNode = null;
                var lowestDistance = double.MaxValue;
                foreach (var node in unsettledNodes)
                {
                    var nodeDistance = node.Distance;
                    if (nodeDistance < lowestDistance)
                    {
                        lowestDistance = nodeDistance;
                        lowestDistanceNode = node;
                    }
                }

                return lowestDistanceNode;
            }

            /// <summary>
            /// Dijkstra from <paramref name="source"/>, storing on every reached node its distance and the list of
            /// nodes preceding it on the chosen path. Nodes keep this state afterwards, so each graph is meant to be
            /// searched once.
            /// </summary>
            public PointGraph CalculateShortestPathFromSource(PointGraph graph, PointNode source)
            {
                source.Distance = 0;
                var settledNodes = new HashSet<PointNode>();
                var unsettledNodes = new HashSet<PointNode> { source };

                while (unsettledNodes.Count != 0)
                {
                    var currentNode = GetLowestDistanceNode(unsettledNodes);
                    unsettledNodes.Remove(currentNode);

                    foreach (var keyValuePair in currentNode.AdjacentNodes)
                    {
                        var adjacentNode = keyValuePair.Key;
                        var edgeWeight = keyValuePair.Value;
                        if (!settledNodes.Contains(adjacentNode))
                        {
                            CalculateMinimumDistance(adjacentNode, edgeWeight, currentNode);
                            unsettledNodes.Add(adjacentNode);
                        }
                    }

                    settledNodes.Add(currentNode);
                }

                return graph;
            }

            private void CalculateMinimumDistance(PointNode evaluationNode, double edgeWeight, PointNode sourceNode)
            {
                var sourceDistance = sourceNode.Distance;
                var comingDirection = InferPathDirection(sourceNode);
                var goingDirection = DirectionOfNodes(sourceNode, evaluationNode);
                var changingDirection = comingDirection != null && goingDirection != null && comingDirection != goingDirection;

                // Penalise turns so that routes with fewer bends are preferred over equally long zig-zags.
                var extraWeigh = changingDirection ? Math.Pow(edgeWeight + 1, 2) : 0;

                if (sourceDistance + edgeWeight + extraWeigh < evaluationNode.Distance)
                {
                    evaluationNode.Distance = sourceDistance + edgeWeight + extraWeigh;
                    var shortestPath = new List<PointNode>();
                    shortestPath.AddRange(sourceNode.ShortestPath);
                    shortestPath.Add(sourceNode);
                    evaluationNode.ShortestPath = shortestPath;
                }
            }

            /// <summary>'v' when a-&gt;b moves along a column (same X), 'h' along a row (same Y), null if diagonal.</summary>
            private char? DirectionOf(PointBase a, PointBase b)
            {
                if (a.X == b.X)
                    return 'v';
                else if (a.Y == b.Y)
                    return 'h';

                return null;
            }

            private char? DirectionOfNodes(PointNode a, PointNode b) => DirectionOf(a.Data, b.Data);

            /// <summary>Direction of the last hop into <paramref name="node"/>, or null if it has no predecessor.</summary>
            private char? InferPathDirection(PointNode node)
            {
                if (node.ShortestPath.Count == 0)
                    return null;

                return DirectionOfNodes(node.ShortestPath[node.ShortestPath.Count - 1], node);
            }

            /// <summary>Adds a directed edge a-&gt;b weighted by their distance; no-op if either point is unknown.</summary>
            public void Connect(PointBase a, PointBase b)
            {
                var nodeA = Get(a);
                var nodeB = Get(b);

                if (nodeA == null || nodeB == null)
                    return;

                nodeA.AdjacentNodes.Add(nodeB, a.DistanceTo(b));
            }

            public bool Has(PointBase p) => Get(p) != null;

            /// <summary>Returns the node at the coordinates of <paramref name="p"/>, or null.</summary>
            public PointNode Get(PointBase p)
            {
                (var x, var y) = p;
                var xs = x.ToInvariantString();
                var ys = y.ToInvariantString();

                if (_index.TryGetValue(xs, out var column) && column.TryGetValue(ys, out var node))
                    return node;

                return null;
            }
        }

        /// <summary>Graph node: a point plus the search state written by <c>PointGraph</c>.</summary>
        class PointNode
        {
            public PointNode(PointBase data)
            {
                Data = data;
            }

            public PointBase Data { get; }

            /// <summary>Best (turn-penalised) distance found from the search source; MaxValue until reached.</summary>
            public double Distance { get; set; } = double.MaxValue;

            /// <summary>Nodes preceding this one on the chosen path, starting at the source; excludes this node.</summary>
            public List<PointNode> ShortestPath { get; set; } = new List<PointNode>();

            public Dictionary<PointNode, double> AdjacentNodes { get; set; } = new Dictionary<PointNode, double>();
        }
    }
}