Skip to content

Dynamic Time Warping

lineagetree.measure

Functions:

Name Description
calculate_dtw

Calculate DTW distance between two chains

calculate_dtw

calculate_dtw(
    lT: LineageTree,
    nodes1: int,
    nodes2: int,
    threshold: int = 1000,
    regist: bool = True,
    start_d: int = 0,
    back_d: int = 0,
    fast: bool = False,
    w: int = 0,
    centered_band: bool = True,
    cost_mat_p: bool = False,
) -> (
    tuple[float, tuple, np.ndarray, np.ndarray, np.ndarray]
    | tuple[float, tuple]
)

Calculate DTW distance between two chains

Parameters:

Name Type Description Default
lT
LineageTree

The LineageTree instance.

required
nodes1
int

node to compare distance

required
nodes2
int

node to compare distance

required
threshold
int

set a maximum number of points a chain can have

1000
regist
bool

Rotate and translate trajectories

True
start_d
int

start delay

0
back_d
int

end delay

0
fast
bool

if True, the algorithm will use a faster version but might not find the optimal alignment

False
w
int

window size

0
centered_band
bool

when running the fast algorithm, True if the window is centered

True
cost_mat_p
bool

If True, also return the unnormalized cost matrix and the positions of both chains

False

Returns:

Type Description
float

DTW distance

tuple of tuples

Alignment path

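np.ndarray

Unnormalized cost matrix (returned only when cost_mat_p is True)

np.ndarray

Positions of the first chain, rotated and translated when regist is True (returned only when cost_mat_p is True)

np.ndarray

Positions of the second chain (returned only when cost_mat_p is True)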

Source code in src/lineagetree/measure/dynamic_time_warping.py
def calculate_dtw(
    lT: LineageTree,
    nodes1: int,
    nodes2: int,
    threshold: int = 1000,
    regist: bool = True,
    start_d: int = 0,
    back_d: int = 0,
    fast: bool = False,
    w: int = 0,
    centered_band: bool = True,
    cost_mat_p: bool = False,
) -> (
    tuple[float, tuple, np.ndarray, np.ndarray, np.ndarray]
    | tuple[float, tuple]
):
    """
    Compute the dynamic time warping (DTW) distance between two chains.

    The chains containing `nodes1` and `nodes2` are retrieved. When
    `regist` is True, the first chain's positions are rigidly rotated and
    translated using a transform estimated from interpolated versions of
    both chains. DTW is then run on the pairwise Euclidean distance matrix
    between the two position sets.

    Parameters
    ----------
    lT : LineageTree
        The LineageTree instance.
    nodes1 : int
        A node whose chain is compared.
    nodes2 : int
        A node whose chain is compared.
    threshold : int, default=1000
        Maximum number of points a chain can have (forwarded to the
        interpolation step).
    regist : bool, default=True
        Rotate and translate the first trajectory before computing
        distances.
    start_d : int, default=0
        Start delay.
    back_d : int, default=0
        End delay.
    fast : bool, default=False
        If `True`, use a faster version of the algorithm that might not find
        the optimal alignment.
    w : int, default=0
        Window size.
    centered_band : bool, default=True
        When running the fast algorithm, `True` if the window is centered.
    cost_mat_p : bool, default=False
        If True, also return the unnormalized cost matrix and the positions
        of both chains.

    Returns
    -------
    float
        DTW distance: the final accumulated cost divided by the path length.
    tuple of tuples
        Alignment path.
    np.ndarray
        Unnormalized cost matrix (only when `cost_mat_p` is True).
    np.ndarray
        Positions of the first chain, registered when `regist` is True
        (only when `cost_mat_p` is True).
    np.ndarray
        Positions of the second chain (only when `cost_mat_p` is True).
    """
    first_chain = lT.get_chain_of_node(nodes1)
    second_chain = lT.get_chain_of_node(nodes2)

    # Interpolated copies of the chains; only consumed by the rigid
    # registration below.
    resampled_first, resampled_second = __interpolate(
        lT, first_chain, second_chain, threshold
    )

    first_pos = np.array([lT.pos[node] for node in first_chain])
    second_pos = np.array([lT.pos[node] for node in second_chain])

    if regist:
        rotation, translation = __rigid_transform_3D(
            np.transpose(resampled_first), np.transpose(resampled_second)
        )
        moved = np.dot(rotation, first_pos.T) + translation
        first_pos = np.transpose(moved)

    pairwise = distance.cdist(first_pos, second_pos, "euclidean")

    warp_path, acc_cost, total_cost = __dp(
        pairwise,
        start_d,
        back_d,
        w=w,
        fast=fast,
        centered_band=centered_band,
    )
    normalized_cost = total_cost / len(warp_path)

    if not cost_mat_p:
        return normalized_cost, warp_path
    return normalized_cost, warp_path, acc_cost, first_pos, second_pos

lineagetree.plot

Functions:

Name Description
plot_dtw_heatmap

Plot DTW cost matrix between two chains in heatmap format

plot_dtw_trajectory

Plots the DTW alignment between the trajectories of two chains in 2D or 3D

plot_dtw_heatmap

plot_dtw_heatmap(
    lT: LineageTree,
    nodes1: int,
    nodes2: int,
    threshold: int = 1000,
    regist: bool = True,
    start_d: int = 0,
    back_d: int = 0,
    fast: bool = False,
    w: int = 0,
    centered_band: bool = True,
) -> tuple[float, plt.Figure]

Plot DTW cost matrix between two chains in heatmap format

Parameters:

Name Type Description Default
lT
LineageTree

The LineageTree instance.

required
nodes1
int

node to compare distance

required
nodes2
int

node to compare distance

required
threshold
int

set a maximum number of points a chain can have

1000
regist
bool

Rotate and translate trajectories

True
start_d
int

start delay

0
back_d
int

end delay

0
fast
bool

if True, the algorithm will use a faster version but might not find the optimal alignment

False
w
int

window size

0
centered_band
bool

when running the fast algorithm, True if the window is centered

True

Returns:

Type Description
float

DTW distance

Figure

Heatmap of the cost matrix with the optimal path

Source code in src/lineagetree/plot.py
def plot_dtw_heatmap(
    lT: LineageTree,
    nodes1: int,
    nodes2: int,
    threshold: int = 1000,
    regist: bool = True,
    start_d: int = 0,
    back_d: int = 0,
    fast: bool = False,
    w: int = 0,
    centered_band: bool = True,
) -> tuple[float, plt.Figure]:
    """
    Plot the DTW cost matrix between two chains as a heatmap.

    The optimal warping path is drawn on top of the matrix.

    Parameters
    ----------
    lT : LineageTree
        The LineageTree instance.
    nodes1 : int
        A node whose chain is compared (shown on the y axis).
    nodes2 : int
        A node whose chain is compared (shown on the x axis).
    threshold : int, default=1000
        Maximum number of points a chain can have.
    regist : bool, default=True
        Rotate and translate trajectories before comparing them.
    start_d : int, default=0
        Start delay.
    back_d : int, default=0
        End delay.
    fast : bool, default=False
        If `True`, use a faster version of the algorithm that might not find
        the optimal alignment.
    w : int, default=0
        Window size.
    centered_band : bool, default=True
        When running the fast algorithm, `True` if the window is centered.

    Returns
    -------
    float
        DTW distance, as returned by `lT.calculate_dtw`.
    plt.Figure
        Heatmap of the cost matrix with the optimal path.
    """
    cost, path, cost_mat, _, _ = lT.calculate_dtw(
        nodes1,
        nodes2,
        threshold,
        regist,
        start_d,
        back_d,
        fast,
        w,
        centered_band,
        cost_mat_p=True,
    )

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(1, 1, 1)
    # The cost matrix is built from cdist(chain1, chain2), so rows index
    # chain 1 and columns index chain 2. imshow puts rows on the y axis and
    # columns on the x axis.
    im = ax.imshow(
        cost_mat, cmap="viridis", origin="lower", interpolation="nearest"
    )
    fig.colorbar(im, ax=ax)
    ax.set_title("Heatmap of DTW Cost Matrix")
    ax.set_xlabel("Tree 2")
    ax.set_ylabel("Tree 1")

    # Each path step is (i, j), with i indexing chain 1 (row / y) and
    # j indexing chain 2 (column / x).
    if len(path) > 0:
        rows, cols = zip(*path, strict=True)
        ax.plot(cols, rows, color="black")

    return cost, fig

plot_dtw_trajectory

plot_dtw_trajectory(
    lT: LineageTree,
    nodes1: int,
    nodes2: int,
    threshold: int = 1000,
    regist: bool = True,
    start_d: int = 0,
    back_d: int = 0,
    fast: bool = False,
    w: int = 0,
    centered_band: bool = True,
    projection: Literal[
        "3d", "xy", "xz", "yz", "pca", None
    ] = None,
    alig: bool = False,
) -> tuple[float, plt.Figure]

Plots the DTW alignment between the trajectories of two chains in 2D or 3D

Parameters:

Name Type Description Default
lT
LineageTree

The LineageTree instance.

required
nodes1
int

node to compare distance

required
nodes2
int

node to compare distance

required
threshold
int

set a maximum number of points a chain can have

1000
regist
bool

Rotate and translate trajectories

True
start_d
int

start delay

0
back_d
int

end delay

0
w
int

window size

0
fast
bool

True if the user wants to run the fast algorithm with window constraints

False
centered_band
bool

if running the fast algorithm, True if the window is centered

True
projection
('3d', 'xy', 'xz', 'yz', 'pca', None)

specify which view to plot -> "3d" : 3D visualization; "xy" or None (default) : 2D projection onto axes x and y; "xz" : 2D projection onto axes x and z; "yz" : 2D projection onto axes y and z; "pca" : PCA projection

None
alig
bool

True to show alignment on plot

False

Returns:

Type Description
float

DTW distance

Figure

Trajectories Plot

Source code in src/lineagetree/plot.py
def plot_dtw_trajectory(
    lT: LineageTree,
    nodes1: int,
    nodes2: int,
    threshold: int = 1000,
    regist: bool = True,
    start_d: int = 0,
    back_d: int = 0,
    fast: bool = False,
    w: int = 0,
    centered_band: bool = True,
    projection: Literal["3d", "xy", "xz", "yz", "pca", None] = None,
    alig: bool = False,
) -> tuple[float, plt.Figure]:
    """
    Plot the DTW alignment between the trajectories of two chains in 2D or 3D.

    Parameters
    ----------
    lT : LineageTree
        The LineageTree instance.
    nodes1 : int
        A node whose chain is compared.
    nodes2 : int
        A node whose chain is compared.
    threshold : int, default=1000
        Maximum number of points a chain can have.
    regist : bool, default=True
        Rotate and translate trajectories before comparing them.
    start_d : int, default=0
        Start delay.
    back_d : int, default=0
        End delay.
    fast : bool, default=False
        True to run the fast algorithm with window constraints.
    w : int, default=0
        Window size.
    centered_band : bool, default=True
        If running the fast algorithm, True if the window is centered.
    projection : {"3d", "xy", "xz", "yz", "pca"} or None, default=None
        Which view to plot:
        "3d" : 3D visualization
        "xy" or None (default) : 2D projection onto axes x and y
        "xz" : 2D projection onto axes x and z
        "yz" : 2D projection onto axes y and z
        "pca" : projection onto the first two principal components of both
        trajectories (requires scikit-learn)
        The aliases "yx", "zx" and "zy" are also accepted.
    alig : bool, default=False
        True to draw the alignment (matched point pairs) on the plot. This is
        not supported with the "pca" projection; a warning is emitted instead.

    Returns
    -------
    float
        DTW distance.
    plt.Figure
        Trajectories plot.

    Raises
    ------
    ValueError
        If `projection` is not one of the supported values.
    ImportError
        If `projection` is "pca" and scikit-learn is not installed.
    """
    # Resolve the projection first, so an invalid value fails before any
    # DTW computation or figure creation. `plane` holds the two coordinate
    # indices and axis labels for the axis-aligned 2D projections.
    if projection == "xy" or projection == "yx" or projection is None:
        plane = (0, 1, "x position", "y position")
    elif projection == "xz" or projection == "zx":
        plane = (0, 2, "x position", "z position")
    elif projection == "yz" or projection == "zy":
        plane = (1, 2, "y position", "z position")
    elif projection == "3d" or projection == "pca":
        plane = None
    else:
        raise ValueError(
            """Error: available projections are:
                    '3d' : for the 3d visualization
                    'xy' or None (default) : 2D projection of axis x and y
                    'xz' : 2D projection of axis x and z
                    'yz' : 2D projection of axis y and z
                    'pca' : PCA projection"""
        )

    if projection == "pca":
        # scikit-learn is optional: fail loudly instead of hitting a
        # NameError on `PCA` further down.
        try:
            from sklearn.decomposition import PCA
        except ImportError as err:
            raise ImportError(
                "scikit-learn is not installed, the PCA orientation cannot be used. "
                "You can install scikit-learn with pip install scikit-learn"
            ) from err

    dtw_distance, alignment, _, pos_chain1, pos_chain2 = lT.calculate_dtw(
        nodes1,
        nodes2,
        threshold,
        regist,
        start_d,
        back_d,
        fast,
        w,
        centered_band,
        cost_mat_p=True,
    )

    fig = plt.figure(figsize=(10, 6))

    if projection == "3d":
        ax = fig.add_subplot(1, 1, 1, projection="3d")
        ax.plot(
            pos_chain1[:, 0],
            pos_chain1[:, 1],
            pos_chain1[:, 2],
            "-",
            label=f"root = {nodes1}",
        )
        ax.plot(
            pos_chain2[:, 0],
            pos_chain2[:, 1],
            pos_chain2[:, 2],
            "-",
            label=f"root = {nodes2}",
        )
        ax.set_ylabel("y position")
        ax.set_xlabel("x position")
        ax.set_zlabel("z position")
        if alig:
            for i, j in alignment:
                p, q = pos_chain1[i], pos_chain2[j]
                ax.plot(
                    [p[0], q[0]], [p[1], q[1]], [p[2], q[2]], "--", color="grey"
                )
    elif projection == "pca":
        ax = fig.add_subplot(1, 1, 1)

        # Fit one PCA on both trajectories so they share the same 2D basis.
        pca = PCA(n_components=2)
        pca.fit(np.vstack([pos_chain1, pos_chain2]))
        pos_chain1_2d = pca.transform(pos_chain1)
        pos_chain2_2d = pca.transform(pos_chain2)

        ax.plot(
            pos_chain1_2d[:, 0],
            pos_chain1_2d[:, 1],
            "-",
            label=f"root = {nodes1}",
        )
        ax.plot(
            pos_chain2_2d[:, 0],
            pos_chain2_2d[:, 1],
            "-",
            label=f"root = {nodes2}",
        )

        # Label each PCA axis with its dominant original axis and that
        # axis's share of the absolute loadings.
        axes = ["x", "y", "z"]
        x_label = axes[np.argmax(np.abs(pca.components_[0]))]
        y_label = axes[np.argmax(np.abs(pca.components_[1]))]
        x_percent = 100 * (
            np.max(np.abs(pca.components_[0]))
            / np.sum(np.abs(pca.components_[0]))
        )
        y_percent = 100 * (
            np.max(np.abs(pca.components_[1]))
            / np.sum(np.abs(pca.components_[1]))
        )
        ax.set_xlabel(f"{x_percent:.0f}% of {x_label} position")
        ax.set_ylabel(f"{y_percent:.0f}% of {y_label} position")
    else:
        ax = fig.add_subplot(1, 1, 1)
        dim_x, dim_y, label_x, label_y = plane
        __plot_2d(
            pos_chain1,
            pos_chain2,
            nodes1,
            nodes2,
            ax,
            dim_x,
            dim_y,
            label_x,
            label_y,
        )
        if alig:
            # Project each matched pair onto the same plane as the plotted
            # trajectories, so the connecting segments line up with them.
            for i, j in alignment:
                p, q = pos_chain1[i], pos_chain2[j]
                ax.plot(
                    [p[dim_x], q[dim_x]],
                    [p[dim_y], q[dim_y]],
                    "--",
                    color="grey",
                )

    ax.set_aspect("equal")
    ax.legend()
    fig.tight_layout()

    if alig and projection == "pca":
        warnings.warn(
            "Error: not possible to show alignment in PCA projection !",
            UserWarning,
            stacklevel=2,
        )

    return dtw_distance, fig