Files
ExplainableAI/extracted_cells/cell24.py

69 lines
2.5 KiB
Python

from sklearn.tree import _tree
def extract_single_rule(tree, feature_names, class_to_extract=1):
"""
Extrahiert eine einzelne Regel aus einem Decision Tree für eine bestimmte Klasse.
Parameters:
-----------
tree : DecisionTreeClassifier
Der trainierte Entscheidungsbaum
feature_names : list
Liste der Feature-Namen
class_to_extract : int, default=1
Die Klasse, für die eine Regel extrahiert werden soll (0=≤50K, 1=>50K)
Returns:
--------
rule : str
Eine lesbare Regel als String
"""
tree_ = tree.tree_
# Funktion zum rekursiven Extrahieren einer Regel
def tree_to_rule(node, depth, conditions):
# Wenn wir einen Blattknoten erreicht haben
if tree_.children_left[node] == _tree.TREE_LEAF:
# Prüfe, ob dieser Blattknoten die gewünschte Klasse vorhersagt
if np.argmax(tree_.value[node][0]) == class_to_extract:
# Formatiere die Bedingungen als Regel
if conditions:
rule = " UND ".join(conditions)
return rule
else:
return "Keine Bedingungen (Wurzelklasse)"
return None
# Feature und Schwellenwert am aktuellen Knoten
feature = feature_names[tree_.feature[node]]
threshold = tree_.threshold[node]
# Linkspfad (≤)
left_conditions = conditions + [f"{feature}{threshold:.2f}"]
left_rule = tree_to_rule(tree_.children_left[node], depth + 1, left_conditions)
if left_rule is not None:
return left_rule
# Rechtspfad (>)
right_conditions = conditions + [f"{feature} > {threshold:.2f}"]
right_rule = tree_to_rule(tree_.children_right[node], depth + 1, right_conditions)
if right_rule is not None:
return right_rule
# Keine passende Regel gefunden
return None
# Starte die Suche vom Wurzelknoten
rule = tree_to_rule(0, 1, [])
# Formatiere die Ausgabe
class_name = "Einkommen > 50K" if class_to_extract == 1 else "Einkommen ≤ 50K"
if rule:
return f"WENN {rule} DANN {class_name}"
else:
return f"Keine Regel für {class_name} gefunden."
# Anwendung für die Extraktion einer Regel für hohes Einkommen (Klasse 1)
single_rule = extract_single_rule(surrogate_tree, X_train.columns.tolist(), class_to_extract=1)
print("Einzelne Regel aus dem Surrogate-Modell:")
print(single_rule)