
Reproducing results on Sachs #6

@rmwu


Hello! I've been trying to replicate the results on Sachs using the provided hyperparameters, but I'm getting an SHD of about 37 to 40 instead of something in the low teens. Any idea why?

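```python
import networkx as nx
from cdt.data import load_dataset
from castle.metrics import MetricsDAG  # gCastle
from diffan.diffan import DiffAN  # import path may differ depending on how DiffAN is installed

# Load the Sachs data (samples x variables) and the ground-truth DAG
data, graph = load_dataset("sachs")
data = data.to_numpy()
# Without nodelist, row/column order follows graph.nodes(), not necessarily the data's column order
graph = nx.to_numpy_array(graph)

num_nodes = data.shape[1]
model = DiffAN(num_nodes, residue=True)
pred_graph, order = model.fit(data)

metrics = MetricsDAG(pred_graph, graph).metrics
```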

This produces the following metrics:

```json
{"fdr": 0.8919, "tpr": 0.2222, "fpr": 0.8919, "shd": 37, "nnz": 37, "precision": 0.1081, "recall": 0.2222, "F1": 0.1455, "gscore": 0.0}
```
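One thing that may be worth ruling out (not confirmed to be the cause here) is node ordering. `nx.to_numpy_array(graph)` orders rows and columns by `graph.nodes()`, which is not guaranteed to match the column order of the data returned by `load_dataset`. If the two orders differ, SHD gets inflated even for a correct prediction. Below is a minimal, dependency-free sketch on a toy 3-node graph. The helpers `shd` and `reorder` are hypothetical illustrations, not part of cdt or gCastle, and the SHD convention used here (each differing node pair costs 1) may not match gCastle's exactly.

```python
from typing import List, Sequence

Matrix = List[List[int]]


def shd(pred: Matrix, true: Matrix) -> int:
    """Structural Hamming distance between two binary adjacency matrices.

    Each unordered node pair whose edge status differs (missing, extra,
    or reversed edge) adds 1. Tools differ on conventions, e.g. some
    count a reversed edge as 2, so absolute numbers may not match gCastle.
    """
    n = len(true)
    dist = 0
    for i in range(n):
        for j in range(i + 1, n):
            if (pred[i][j], pred[j][i]) != (true[i][j], true[j][i]):
                dist += 1
    return dist


def reorder(adj: Matrix, current: Sequence[str], target: Sequence[str]) -> Matrix:
    """Permute rows and columns of `adj` from node order `current` to `target`."""
    idx = [list(current).index(name) for name in target]
    return [[adj[r][c] for c in idx] for r in idx]


# Toy DAG a -> b -> c, written in the data's column order [a, b, c].
data_order = ["a", "b", "c"]
pred = [[0, 1, 0],
        [0, 0, 1],
        [0, 0, 0]]  # a perfect prediction, indexed like the data

# The same DAG, but stored in a different node order [c, a, b].
graph_order = ["c", "a", "b"]
truth_graph_order = [[0, 0, 0],
                     [0, 0, 1],   # a -> b
                     [1, 0, 0]]   # b -> c

print(shd(pred, truth_graph_order))  # misaligned orders: 2
print(shd(pred, reorder(truth_graph_order, graph_order, data_order)))  # aligned: 0
```

In the actual script, the equivalent check would be to capture `data.columns` before calling `.to_numpy()` and pass it as `nodelist=` to `nx.to_numpy_array`, so that the predicted and ground-truth matrices index the variables the same way.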
