Spaces:
Running
Running
| # -*- coding: utf-8 -*- | |
| """ | |
| Render a JSON-aware visualization of CAIS's rule-based method selector. | |
| - Parses a CAIS run payload (dict) and highlights ALL plausible candidates (green). | |
| - The actually selected method receives a thicker border. | |
| - The traversed decision path edges are colored. | |
| Usage: | |
| render_from_json(payload_dict, out_stem="artifacts/decision_tree") | |
| (Optional) CLI: | |
| python decision_tree.py payload.json | |
| """ | |
| from graphviz import Digraph | |
| import json, sys | |
| from typing import Dict, Any, List, Set, Tuple, Optional | |
| from auto_causal.components.decision_tree import ( | |
| DIFF_IN_MEANS, LINEAR_REGRESSION, DIFF_IN_DIFF, REGRESSION_DISCONTINUITY, | |
| INSTRUMENTAL_VARIABLE, PROPENSITY_SCORE_MATCHING, PROPENSITY_SCORE_WEIGHTING, | |
| GENERALIZED_PROPENSITY_SCORE, BACKDOOR_ADJUSTMENT, FRONTDOOR_ADJUSTMENT | |
| ) | |
# Human-readable text shown on each leaf node; keys are the method constants
# imported from auto_causal.components.decision_tree.
LABEL: Dict[str, str] = dict(
    [
        (DIFF_IN_MEANS, "Diff-in-Means (RCT)"),
        (LINEAR_REGRESSION, "Linear Regression"),
        (DIFF_IN_DIFF, "Difference-in-Differences"),
        (REGRESSION_DISCONTINUITY, "Regression Discontinuity"),
        (INSTRUMENTAL_VARIABLE, "Instrumental Variables"),
        (PROPENSITY_SCORE_MATCHING, "PS Matching"),
        (PROPENSITY_SCORE_WEIGHTING, "PS Weighting"),
        (GENERALIZED_PROPENSITY_SCORE, "Generalized PS (continuous T)"),
        (BACKDOOR_ADJUSTMENT, "Backdoor Adjustment"),
        (FRONTDOOR_ADJUSTMENT, "Frontdoor Adjustment"),
    ]
)
| # -------- Heuristic extractors from payload -------- # | |
| def _get(d: Dict, path: List[str], default=None): | |
| cur = d | |
| for k in path: | |
| if not isinstance(cur, dict) or k not in cur: | |
| return default | |
| cur = cur[k] | |
| return cur | |
def extract_signals(p: Dict[str, Any]) -> Dict[str, Any]:
    """Derive the signals the decision tree branches on from a CAIS payload *p*.

    The ``variables`` and ``dataset_analysis`` sections are looked up under
    ``results`` first and then at the top level of the payload; empty or
    non-dict sections are skipped.

    Returns a dict with keys ``treatment``, ``t_type``, ``is_rct``,
    ``has_temporal``, ``rdd_ready``, ``has_valid_instrument``,
    ``has_covariates``, ``frontdoor_ok`` and ``strong_overlap`` (the raw
    ``overlap_assessment["strong_overlap"]`` value, or ``None`` when absent).
    """

    def _section(name: str) -> Dict[str, Any]:
        # First non-empty dict among results.<name> and <name>; {} otherwise.
        for candidate in (_get(p, ["results", name]), _get(p, [name])):
            if isinstance(candidate, dict) and candidate:
                return candidate
        return {}

    vars_ = _section("variables")
    da = _section("dataset_analysis")

    treatment = vars_.get("treatment_variable")
    t_type = vars_.get("treatment_variable_type")  # e.g. "binary" / "continuous"
    is_rct = bool(vars_.get("is_rct", False))

    # Temporal / panel: the detector flag, a time variable or a group variable suffices.
    temporal_detected = bool(da.get("temporal_structure_detected", False))
    time_var = vars_.get("time_variable")
    group_var = vars_.get("group_variable")
    has_temporal = temporal_detected or bool(time_var) or bool(group_var)

    # RDD requires both a running variable and a cutoff. Some detectors also
    # report 'discontinuities_detected'; that flag alone is deliberately not
    # enough (add it here if more permissive behavior is wanted).
    running_variable = vars_.get("running_variable")
    cutoff_value = vars_.get("cutoff_value")
    rdd_ready = running_variable is not None and cutoff_value is not None

    # Instruments: valid only if present and NOT the treatment itself.
    instrument = vars_.get("instrument_variable")
    pot_instr = da.get("potential_instruments") or []
    if isinstance(pot_instr, str):
        # A single name given as a bare string would otherwise be iterated per character.
        pot_instr = [pot_instr]
    # NOTE(review): assumes potential_instruments holds variable names; if the
    # payload stores dicts instead, every entry would count as valid — confirm schema.
    has_valid_instrument = (
        instrument is not None and instrument != treatment
    ) or any(pi and pi != treatment for pi in pot_instr)

    has_covariates = bool(vars_.get("covariates"))

    # Frontdoor: only mark if explicitly provided (else too speculative). Read it
    # from the same resolved section as the other dataset-analysis signals so a
    # top-level "dataset_analysis" is honoured too.
    frontdoor_ok = bool(da.get("frontdoor_satisfied", False))

    # Overlap: if explicitly known, use it; else None, meaning unknown, so both
    # PS variants remain plausible downstream.
    overlap_assessment = da.get("overlap_assessment")
    strong_overlap = None
    if isinstance(overlap_assessment, dict):
        # accept typical keys like {"strong_overlap": true}
        strong_overlap = overlap_assessment.get("strong_overlap")

    return dict(
        treatment=treatment,
        t_type=t_type,
        is_rct=is_rct,
        has_temporal=has_temporal,
        rdd_ready=rdd_ready,
        has_valid_instrument=has_valid_instrument,
        has_covariates=has_covariates,
        frontdoor_ok=frontdoor_ok,
        strong_overlap=strong_overlap,
    )
| # -------- Candidate inference (green leaves) -------- # | |
def infer_candidate_methods(signals: Dict[str, Any]) -> Set[str]:
    """Return every method that is plausible given *signals* (the green leaves).

    Randomized designs short-circuit: Diff-in-Means always, Linear Regression
    when covariates exist, IV when a valid instrument exists (e.g. randomized
    encouragement). Observational designs accumulate every method whose
    prerequisite signal holds.
    """
    if signals["is_rct"]:
        rct_methods: Set[str] = {DIFF_IN_MEANS}
        if signals["has_covariates"]:
            rct_methods.add(LINEAR_REGRESSION)
        if signals["has_valid_instrument"]:
            rct_methods.add(INSTRUMENTAL_VARIABLE)
        return rct_methods  # the observational tree is irrelevant for RCTs

    # Observational methods gated by a single boolean signal each.
    gated = (
        ("has_temporal", DIFF_IN_DIFF),
        ("rdd_ready", REGRESSION_DISCONTINUITY),
        ("has_valid_instrument", INSTRUMENTAL_VARIABLE),
        ("frontdoor_ok", FRONTDOOR_ADJUSTMENT),
    )
    found: Set[str] = {method for key, method in gated if signals[key]}

    if str(signals["t_type"]).lower() == "continuous":
        found.add(GENERALIZED_PROPENSITY_SCORE)

    # Backdoor / PS all need covariates.
    if not signals["has_covariates"]:
        return found

    # Known overlap picks exactly one PS variant; unknown overlap keeps both.
    overlap = signals["strong_overlap"]
    if overlap is not False:
        found.add(PROPENSITY_SCORE_MATCHING)
    if overlap is not True:
        found.add(PROPENSITY_SCORE_WEIGHTING)
    found.add(BACKDOOR_ADJUSTMENT)
    return found
| # -------- Compute the single realized path to the chosen leaf (for edge coloring) -------- # | |
def infer_decision_path(signals: Dict[str, Any], selected_method: Optional[str]) -> List[Tuple[str, str]]:
    """Return the (parent, child) edges walked from ``start`` to a single leaf.

    The walk is driven entirely by *signals*; *selected_method* is accepted
    but not consulted. When overlap is unknown the walk defaults to the
    weighting leaf.
    """
    edges: List[Tuple[str, str]] = [("start", "is_rct")]

    if signals["is_rct"]:
        edges.append(("is_rct", "has_instr_rct"))
        if signals["has_valid_instrument"]:
            edges.append(("has_instr_rct", INSTRUMENTAL_VARIABLE))
        else:
            edges.append(("has_instr_rct", "has_cov_rct"))
            rct_leaf = LINEAR_REGRESSION if signals["has_covariates"] else DIFF_IN_MEANS
            edges.append(("has_cov_rct", rct_leaf))
        return edges

    edges.append(("is_rct", "has_temporal"))

    # Observational chain: each question either ends at its leaf ("Yes") or
    # hands off to the next question ("No"). Predicates are evaluated lazily,
    # in order, so later signals are only read when reached.
    steps = (
        ("has_temporal", lambda s: s["has_temporal"], DIFF_IN_DIFF, "has_rv"),
        ("has_rv", lambda s: s["rdd_ready"], REGRESSION_DISCONTINUITY, "has_instr"),
        ("has_instr", lambda s: s["has_valid_instrument"], INSTRUMENTAL_VARIABLE, "frontdoor"),
        ("frontdoor", lambda s: s["frontdoor_ok"], FRONTDOOR_ADJUSTMENT, "t_cont"),
        ("t_cont", lambda s: str(s["t_type"]).lower() == "continuous",
         GENERALIZED_PROPENSITY_SCORE, "has_cov"),
    )
    for node, holds, leaf, next_node in steps:
        if holds(signals):
            edges.append((node, leaf))
            return edges
        edges.append((node, next_node))

    if not signals["has_covariates"]:
        # Mirrors the graph topology, where "No covariates" points at Backdoor Adjustment.
        edges.append(("has_cov", BACKDOOR_ADJUSTMENT))
        return edges

    edges.append(("has_cov", "overlap"))
    ps_leaf = PROPENSITY_SCORE_MATCHING if signals["strong_overlap"] is True else PROPENSITY_SCORE_WEIGHTING
    edges.append(("overlap", ps_leaf))
    return edges
| # -------- Graph building -------- # | |
def _resolve_selected_method(payload: Dict[str, Any]) -> Optional[str]:
    """Map the method name reported in *payload* to one of the method constants.

    Looks at ``results.results.method_used``, then ``results.method_used``,
    then ``method``. Accepts a value that is already a method constant, or a
    snake_case name (case, surrounding whitespace, spaces and hyphens are
    normalised). Returns ``None`` when nothing is reported or the name is unknown.
    """
    raw = (
        _get(payload, ["results", "results", "method_used"])
        or _get(payload, ["results", "method_used"])
        or _get(payload, ["method"])
    )
    if not raw:
        return None
    # The payload may already carry the canonical constant value.
    if isinstance(raw, str) and raw in LABEL:
        return raw
    key = str(raw).strip().lower().replace("-", "_").replace(" ", "_")
    aliases = {
        "linear_regression": LINEAR_REGRESSION,
        "diff_in_means": DIFF_IN_MEANS,
        "difference_in_differences": DIFF_IN_DIFF,
        "regression_discontinuity": REGRESSION_DISCONTINUITY,
        "instrumental_variable": INSTRUMENTAL_VARIABLE,
        "propensity_score_matching": PROPENSITY_SCORE_MATCHING,
        "propensity_score_weighting": PROPENSITY_SCORE_WEIGHTING,
        "generalized_propensity_score": GENERALIZED_PROPENSITY_SCORE,
        "backdoor_adjustment": BACKDOOR_ADJUSTMENT,
        "frontdoor_adjustment": FRONTDOOR_ADJUSTMENT,
    }
    return aliases.get(key)


def build_graph(payload: Dict[str, Any]) -> Digraph:
    """Build the decision-tree diagram (SVG format by default) for a CAIS *payload*.

    Leaves whose method is a plausible candidate are filled pale green, the
    method reported as used gets a thicker border, and the edges of the
    signal-driven decision path are drawn in forest green.
    """
    g = Digraph("CAISDecisionTree", format="svg")
    g.attr(rankdir="LR", nodesep="0.4", ranksep="0.35", fontsize="11")

    # Decisions
    g.node("start", "Start", shape="circle")
    g.node("is_rct", "Is RCT?", shape="diamond")
    g.node("has_instr_rct", "Instrument available?", shape="diamond")
    g.node("has_cov_rct", "Covariates observed?", shape="diamond")
    g.node("has_temporal", "Temporal structure?", shape="diamond")
    g.node("has_rv", "Running var & cutoff?", shape="diamond")
    g.node("has_instr", "Instrument available?", shape="diamond")
    g.node("frontdoor", "Frontdoor criterion satisfied?", shape="diamond")
    g.node("has_cov", "Covariates observed?", shape="diamond")
    g.node("overlap", "Strong overlap?\n(overlap ≥ 0.1)", shape="diamond")
    g.node("t_cont", "Treatment continuous?", shape="diamond")

    # Compute signals, candidates and the selected method.
    signals = extract_signals(payload)
    candidates = infer_candidate_methods(signals)
    selected_method = _resolve_selected_method(payload)

    # Leaves: one per labelled method, coloured by candidacy / selection.
    for method, text in LABEL.items():
        attrs = {"shape": "box", "style": "rounded"}
        if method in candidates:
            attrs.update(style="rounded,filled", fillcolor="palegreen")
        if method == selected_method:
            attrs["penwidth"] = "2"
        g.node(method, text, **attrs)

    # Edges, with the realized path highlighted.
    path_edges = set(infer_decision_path(signals, selected_method))

    def edge(u: str, v: str, label: Optional[str] = None) -> None:
        attrs: Dict[str, str] = {} if label is None else {"label": label}
        if (u, v) in path_edges:
            attrs.update(color="forestgreen", penwidth="2")
        g.edge(u, v, **attrs)

    topology = (
        ("start", "is_rct", None),
        # RCT branch
        ("is_rct", "has_instr_rct", "Yes"),
        ("has_instr_rct", INSTRUMENTAL_VARIABLE, "Yes"),
        ("has_instr_rct", "has_cov_rct", "No"),
        ("has_cov_rct", LINEAR_REGRESSION, "Yes"),
        ("has_cov_rct", DIFF_IN_MEANS, "No"),
        # Observational branch
        ("is_rct", "has_temporal", "No"),
        ("has_temporal", DIFF_IN_DIFF, "Yes"),
        ("has_temporal", "has_rv", "No"),
        ("has_rv", REGRESSION_DISCONTINUITY, "Yes"),
        ("has_rv", "has_instr", "No"),
        ("has_instr", INSTRUMENTAL_VARIABLE, "Yes"),
        ("has_instr", "frontdoor", "No"),
        ("frontdoor", FRONTDOOR_ADJUSTMENT, "Yes"),
        ("frontdoor", "t_cont", "No"),
        ("t_cont", GENERALIZED_PROPENSITY_SCORE, "Yes"),
        ("t_cont", "has_cov", "No"),
        ("has_cov", "overlap", "Yes"),
        ("has_cov", BACKDOOR_ADJUSTMENT, "No"),
        ("overlap", PROPENSITY_SCORE_MATCHING, "Yes"),
        ("overlap", PROPENSITY_SCORE_WEIGHTING, "No"),
    )
    for u, v, label in topology:
        edge(u, v, label)

    # Legend, attached to the start node with an undirected dashed edge.
    g.node("legend", "Legend:\nGreen = plausible candidate(s)\nBold border = method used", shape="note")
    g.edge("legend", "start", style="dashed", arrowhead="none")
    return g
def render_from_json(
    payload: Dict[str, Any],
    out_stem: str = "artifacts/decision_tree",
    formats: Tuple[str, ...] = ("svg", "png"),
) -> List[str]:
    """Render the decision tree for *payload* to disk.

    Writes the Graphviz source to ``{out_stem}.dot``, then renders one image
    per entry in *formats*. ``cleanup=True`` removes the intermediate source
    file that each ``render`` call writes next to the image.

    Returns the paths reported by Graphviz for the rendered images, in
    *formats* order.
    """
    g = build_graph(payload)
    g.save(filename=f"{out_stem}.dot")
    rendered: List[str] = []
    for fmt in formats:
        g.format = fmt
        rendered.append(g.render(filename=out_stem, cleanup=True))
    return rendered
def main(argv: Optional[List[str]] = None) -> None:
    """CLI entry point: ``python decision_tree.py [payload.json [out_stem]]``.

    Reads the JSON payload from the first argument, or from
    ``sample_output.json`` in the working directory when no argument is given.
    An optional second argument overrides the output stem passed to
    ``render_from_json``. *argv* defaults to ``sys.argv[1:]``.
    """
    args = sys.argv[1:] if argv is None else argv
    payload_path = args[0] if args else "sample_output.json"
    with open(payload_path, "r", encoding="utf-8") as f:
        payload = json.load(f)
    if len(args) >= 2:
        render_from_json(payload, out_stem=args[1])
    else:
        render_from_json(payload)


if __name__ == "__main__":
    main()