Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import React, { useMemo } from 'react';
import {
BaseEdge,
EdgeLabelRenderer,
getSmoothStepPath,
getBezierPath,
useStore,
} from '@xyflow/react';
import type { EdgeProps, ReactFlowState } from '@xyflow/react';
Expand All @@ -20,6 +20,7 @@ type SpecialPathParams = Pick<
'sourceX' | 'sourceY' | 'targetX' | 'targetY'
>;

// Calculate offset for edges sharing the same source-target pair (bidirectional)
const getBidirectionalOffset = ({
sourceX,
sourceY,
Expand Down Expand Up @@ -104,54 +105,101 @@ const CustomEdge: React.FC<CustomEdgeProps> = (props) => {
selected,
markerEnd,
style,
pathOptions,
} = props;

const { isBidirectional, shouldCurve } = useStore((state: ReactFlowState) => {
const partner = state.edges.find(
(edge) =>
edge.id !== id && edge.source === target && edge.target === source,
);
// Detect bidirectional edges AND edges sharing the same source
const { isBidirectional, shouldCurve, edgeIndex, totalSiblingEdges } = useStore(
(state: ReactFlowState) => {
// Check for bidirectional partner (reverse edge)
const partner = state.edges.find(
(edge) =>
edge.id !== id && edge.source === target && edge.target === source,
);

if (!partner) {
return { isBidirectional: false, shouldCurve: false };
}
// Find all edges from the same source (sibling edges)
const siblingEdges = state.edges.filter(
(edge) => edge.source === source && edge.id !== id,
);

const shouldCurve = id.localeCompare(partner.id) > 0;
// Find index of this edge among all edges from the same source
const allEdgesFromSource = state.edges.filter(
(edge) => edge.source === source,
);
const myIndex = allEdgesFromSource.findIndex((edge) => edge.id === id);

return { isBidirectional: true, shouldCurve };
});
if (!partner) {
return {
isBidirectional: false,
shouldCurve: siblingEdges.length > 0, // Curve if there are sibling edges
edgeIndex: myIndex,
totalSiblingEdges: allEdgesFromSource.length,
};
}

const shouldCurve = id.localeCompare(partner.id) > 0;

const [defaultPath, defaultLabelX, defaultLabelY] = getSmoothStepPath({
return {
isBidirectional: true,
shouldCurve,
edgeIndex: myIndex,
totalSiblingEdges: allEdgesFromSource.length,
};
},
);

// Calculate curvature based on edge index to prevent overlapping
const curvatureOffset = useMemo(() => {
if (totalSiblingEdges <= 1) return 0;

// Distribute edges evenly with different curve offsets
const baseOffset = 25; // Base offset between edges
const centerIndex = (totalSiblingEdges - 1) / 2;
return (edgeIndex - centerIndex) * baseOffset;
}, [edgeIndex, totalSiblingEdges]);

// Use bezier path for smoother curves
const [defaultPath, defaultLabelX, defaultLabelY] = getBezierPath({
sourceX,
sourceY,
targetX,
targetY,
sourcePosition,
targetPosition,
borderRadius: pathOptions?.borderRadius ?? 10,
curvature: 0.25 + Math.abs(curvatureOffset) * 0.01, // Adjust curvature based on offset
});

let edgePath = defaultPath;
let labelX = defaultLabelX;
let labelY = defaultLabelY;

if (shouldCurve) {
const offset = getBidirectionalOffset({
sourceX,
sourceY,
targetX,
targetY,
});
edgePath = getSpecialPath({ sourceX, sourceY, targetX, targetY }, offset);
const labelPosition = getBidirectionalLabelPosition(
{ sourceX, sourceY, targetX, targetY },
offset,
);
labelX = labelPosition.x;
labelY = labelPosition.y;
} else if (isBidirectional) {
// keep default smooth step path for the partner edge
// Apply additional offset for bidirectional or sibling edges
if (shouldCurve || totalSiblingEdges > 1) {
let offset: number;

if (isBidirectional && shouldCurve) {
// Bidirectional pair - use existing logic
offset = getBidirectionalOffset({
sourceX,
sourceY,
targetX,
targetY,
});
} else if (totalSiblingEdges > 1) {
// Multiple edges from same source - apply index-based offset
offset = curvatureOffset;
} else {
offset = 0;
}

if (offset !== 0) {
edgePath = getSpecialPath({ sourceX, sourceY, targetX, targetY }, offset);
const labelPosition = getBidirectionalLabelPosition(
{ sourceX, sourceY, targetX, targetY },
offset,
);
labelX = labelPosition.x;
labelY = labelPosition.y;
}
}

const truncatedLabel = useMemo(() => {
Expand Down Expand Up @@ -199,6 +247,7 @@ const CustomEdge: React.FC<CustomEdgeProps> = (props) => {

const edgeColor = selected ? '#0078d4' : '#b1b1b7';
const edgeWidth = selected ? 2 : 1.5;

return (
<>
<BaseEdge
Expand Down