# Replay strategies

import networkx as nx
from tqdm.auto import tqdm

from caveclient import CAVEclient
from paleo import (
    get_initial_graph,
    get_operations_level2_edits,
    get_nucleus_supervoxel,
    get_node_aliases,
    get_detailed_change_log,
)
# Neuron to replay, with a client pinned to materialization version 1078 so
# that every query below sees the same data.
root_id = 864691135639556411
client = CAVEclient("minnie65_public", version=1078)

# Unfiltered edit history for the neuron (indexed by operation ID). The bare
# name on the next line only displays the table when run as a notebook cell.
change_log = get_detailed_change_log(root_id, client, filtered=False)
change_log

# Level-2 graph changes for every operation in the history, plus the level-2
# graph as it was before any of those edits.
networkdeltas = get_operations_level2_edits(change_log.index, client)
initial_graph = get_initial_graph(root_id, client)

# Level-2 aliases of the nucleus supervoxel; their IDs (the index) are the
# anchor nodes whose connected component is tracked during the replay.
nuc_supervoxel_id = get_nucleus_supervoxel(root_id, client)
node_info = get_node_aliases(nuc_supervoxel_id, client, stop_layer=2)
anchor_nodes = node_info.index
from paleo import get_metaedits

metaedits, operation_map = get_metaedits(networkdeltas)
# Tag every operation in the change log with the metaedit it was folded into.
change_log["metaedit"] = change_log.index.map(operation_map)

# One row per metaedit: earliest timestamp among its operations, whether any
# of them was a merge, and how many operations it bundles.
metaedit_info = change_log.groupby("metaedit").agg(
    first_timestamp=("timestamp", "min"),
    has_merge=("is_merge", "any"),
    n_edits=("metaedit", "count"),
)
# Attach the operation IDs grouped under each metaedit (aligned on the
# metaedit index).
metaedit_info["operation_ids"] = (
    change_log.reset_index().groupby("metaedit")["operation_id"].unique()
)

# Put metaedits in chronological order and display the table. Then display
# whether that chronological order also has increasing metaedit IDs.
metaedit_info = metaedit_info.sort_values("first_timestamp")
metaedit_info
metaedit_info.index.is_monotonic_increasing
from paleo import resolve_edit

# Work on a private copy so that initial_graph keeps the pre-edit state.
graph = initial_graph.copy()

components = []
subgraphs = []

# Prepend a placeholder step (-1, no delta) so that the initial, unedited
# state is recorded as the first entry. Unpacking already builds a new dict,
# so metaedits itself is left untouched.
edits = {-1: None, **metaedits}

# Replay the metaedits in order. resolve_edit is given the working graph, the
# step's delta and the nucleus anchor nodes, and returns the component to
# track. Presumably it applies the delta to ``graph`` in place (confirm in
# paleo). Keep the component's nodes and a standalone copy of its induced
# subgraph.
for edit_id, delta in tqdm(edits.items()):
    component = resolve_edit(graph, delta, anchor_nodes)
    subgraph = graph.subgraph(component).copy()
    components.append(component)
    subgraphs.append(subgraph)
# Sanity check: the last replayed state should match the neuron's current
# level-2 graph. Report discrepancies in both directions.
final_edgelist = client.chunkedgraph.level2_chunk_graph(root_id)
final_graph = nx.from_edgelist(final_edgelist)
import numpy as np

final_nodes = list(final_graph.nodes)
replayed_nodes = list(subgraphs[-1].nodes)
# Nodes in the current graph that the replay never reached.
print(np.setdiff1d(final_nodes, replayed_nodes))
# Nodes the replay reached that are absent from the current graph.
print(np.setdiff1d(replayed_nodes, final_nodes))

# Every level-2 node that was part of the tracked component at any step.
# A set union is used instead of np.concatenate: concatenating an empty
# component promotes the result to float64, which corrupts 64-bit IDs, and
# an empty list of components would raise.
used_l2_ids = np.array(sorted(set().union(*components)))

l2data = client.l2cache.get_l2data_table(used_l2_ids, split_columns=True)
l2data
import pandas as pd

from pcg_skel import pcg_skeleton_direct

nuc_table = client.info.get_datastack_info()["soma_table"]
# Request positions in nanometers. The nucleus point is used as the root of
# skeletons whose vertices are the level-2 ``rep_coord_nm_*`` coordinates,
# and it is passed to pcg_skeleton with root_point_resolution=[1, 1, 1]. The
# table's native (voxel) resolution would put it in the wrong units.
nuc_info = client.materialize.query_table(
    nuc_table,
    filter_equal_dict=dict(pt_root_id=root_id),
    desired_resolution=[1, 1, 1],
)
if len(nuc_info) == 0:
    raise ValueError(f"No nucleus found in {nuc_table!r} for root ID {root_id}")
# NOTE(review): if several nuclei match, the first row is used.
nuc_loc = nuc_info["pt_position"].values[0]


def skeletonize(subgraph):
    """Skeletonize one level-2 connected component.

    Vertices are the level-2 representative coordinates (``rep_coord_nm_*``
    columns of the module-level ``l2data`` table). Edges are the subgraph's
    edges, re-expressed as row positions into that vertex array. The skeleton
    is rooted at the module-level ``nuc_loc``.

    Parameters
    ----------
    subgraph : networkx.Graph
        Level-2 graph whose node IDs all appear in the index of ``l2data``.

    Returns
    -------
    The skeleton object produced by ``pcg_skeleton_direct``.
    """
    component = pd.Index(list(subgraph.nodes))
    vertices = l2data.loc[
        component, ["rep_coord_nm_x", "rep_coord_nm_y", "rep_coord_nm_z"]
    ].values
    # Take only the endpoint pairs. nx.to_pandas_edgelist would also emit one
    # column per edge attribute, which would then leak into the edge array.
    # The reshape keeps an (n_edges, 2) shape even for an edgeless,
    # single-node component; there np.vectorize used to raise on empty input.
    edge_pairs = np.asarray(list(subgraph.edges)).reshape(-1, 2)
    # Map level-2 IDs to their row positions in ``vertices`` in one pass.
    edges = component.get_indexer(edge_pairs.ravel()).reshape(edge_pairs.shape)
    skeleton = pcg_skeleton_direct(vertices, edges, root_point=nuc_loc)
    return skeleton


# Skeletonize the tracked component after every replay step. The dict is
# keyed by edit ID (-1 is the initial, pre-edit state) rather than by the
# Graph object, so that later code can iterate ``skeletons.items()`` as
# (edit_id, skeleton) pairs in replay order. ``edits`` and ``subgraphs`` are
# built in lockstep by the replay loop, one subgraph per edit.
skeletons = {}
for edit_id, subgraph in tqdm(
    zip(edits.keys(), subgraphs), total=len(subgraphs), desc="Skeletonizing"
):
    skeleton = skeletonize(subgraph)
    skeletons[edit_id] = skeleton


# graph = initial_graph.copy()
# # keep track of components that are reached as we go
# components = []
# # remember to include the initial state
# metaedits = {-1: None, **metaedits}

# skeletons = {}
# # after each edit, apply it and store the connected component for the nucleus node
# for edit_id, delta in tqdm(metaedits.items(), disable=False):
#     component = resolve_edit(graph, delta, node_info.index)
#     skeleton = skeletonize(graph, component)
#     skeletons[edit_id] = skeleton
import pyvista as pv

# Plotter whose written frames are recorded into an animated GIF at 30 fps.
plotter = pv.Plotter()
plotter.open_gif(filename="skeleton_evolution.gif", fps=30)


def skel_to_poly(skeleton):
    """Convert a skeleton into a PyVista line mesh.

    Each skeleton edge ``(i, j)`` becomes one two-point line cell, written in
    VTK's cell-connectivity layout as the row ``[2, i, j]``.
    """
    segments = np.full((len(skeleton.edges), 3), 2)
    segments[:, 1:] = skeleton.edges
    return pv.PolyData(skeleton.vertices, lines=segments)


# Render the skeleton evolution into the GIF opened on ``plotter``. The
# try/finally makes sure the plotter (and with it the GIF writer) is closed
# even if a frame fails to render.
try:
    # First frame: the final skeleton on its own. Presumably this is done so
    # that the view is fit to the full neuron before the evolution frames.
    last_skeleton = list(skeletons.values())[-1]
    actor = plotter.add_mesh(
        skel_to_poly(last_skeleton), color="black", line_width=2
    )
    plotter.write_frame()
    plotter.remove_actor(actor)

    # The initial skeleton stays on screen (thick, blue) as a reference.
    first_skeleton = list(skeletons.values())[0]
    actor = plotter.add_mesh(
        skel_to_poly(first_skeleton), color="blue", line_width=5
    )

    # One frame per replay step: draw that step's skeleton, capture, remove.
    for edit_id, skeleton in skeletons.items():
        line_poly = skel_to_poly(skeleton)
        actor = plotter.add_mesh(line_poly, color="black", line_width=2)
        plotter.write_frame()
        plotter.remove_actor(actor)
finally:
    plotter.close()
import time
from pcg_skel import pcg_skeleton

# Skeletonize the current root ID directly with pcg_skel (from scratch, not
# via the replay) and report how long it took; ``currtime`` was previously
# recorded but never used.
currtime = time.time()
scratch_skeleton = pcg_skeleton(
    root_id,
    client,
    root_point=nuc_loc,
    root_point_resolution=[1, 1, 1],
)
print(f"pcg_skeleton took {time.time() - currtime:.2f} s")
from paleo import get_all_time_synapses

from paleo import get_nodes_aliases

# Synapses attached to this neuron at any point in its edit history, split
# into the pre- and postsynaptic tables returned by get_all_time_synapses.
pre_synapses, post_synapses = get_all_time_synapses(root_id, client, verbose=True)

# Distinct presynaptic supervoxels, how many there are, and their aliases.
supervoxel_ids = pre_synapses["pre_pt_supervoxel_id"].unique()
print(len(supervoxel_ids))
pre_l2_mappings = get_nodes_aliases(supervoxel_ids, client)

# Distinct postsynaptic supervoxels. Note that this rebinds supervoxel_ids;
# mapping them is currently disabled.
supervoxel_ids = post_synapses["post_pt_supervoxel_id"].unique()
print(len(supervoxel_ids))
# post_l2_mappings = get_nodes_aliases(supervoxel_ids, client)