69 lines
2.5 KiB
Python
69 lines
2.5 KiB
Python
from sklearn.tree import _tree
|
|
|
|
def extract_single_rule(tree, feature_names, class_to_extract=1):
|
|
"""
|
|
Extrahiert eine einzelne Regel aus einem Decision Tree für eine bestimmte Klasse.
|
|
|
|
Parameters:
|
|
-----------
|
|
tree : DecisionTreeClassifier
|
|
Der trainierte Entscheidungsbaum
|
|
feature_names : list
|
|
Liste der Feature-Namen
|
|
class_to_extract : int, default=1
|
|
Die Klasse, für die eine Regel extrahiert werden soll (0=≤50K, 1=>50K)
|
|
|
|
Returns:
|
|
--------
|
|
rule : str
|
|
Eine lesbare Regel als String
|
|
"""
|
|
tree_ = tree.tree_
|
|
|
|
# Funktion zum rekursiven Extrahieren einer Regel
|
|
def tree_to_rule(node, depth, conditions):
|
|
# Wenn wir einen Blattknoten erreicht haben
|
|
if tree_.children_left[node] == _tree.TREE_LEAF:
|
|
# Prüfe, ob dieser Blattknoten die gewünschte Klasse vorhersagt
|
|
if np.argmax(tree_.value[node][0]) == class_to_extract:
|
|
# Formatiere die Bedingungen als Regel
|
|
if conditions:
|
|
rule = " UND ".join(conditions)
|
|
return rule
|
|
else:
|
|
return "Keine Bedingungen (Wurzelklasse)"
|
|
return None
|
|
|
|
# Feature und Schwellenwert am aktuellen Knoten
|
|
feature = feature_names[tree_.feature[node]]
|
|
threshold = tree_.threshold[node]
|
|
|
|
# Linkspfad (≤)
|
|
left_conditions = conditions + [f"{feature} ≤ {threshold:.2f}"]
|
|
left_rule = tree_to_rule(tree_.children_left[node], depth + 1, left_conditions)
|
|
if left_rule is not None:
|
|
return left_rule
|
|
|
|
# Rechtspfad (>)
|
|
right_conditions = conditions + [f"{feature} > {threshold:.2f}"]
|
|
right_rule = tree_to_rule(tree_.children_right[node], depth + 1, right_conditions)
|
|
if right_rule is not None:
|
|
return right_rule
|
|
|
|
# Keine passende Regel gefunden
|
|
return None
|
|
|
|
# Starte die Suche vom Wurzelknoten
|
|
rule = tree_to_rule(0, 1, [])
|
|
|
|
# Formatiere die Ausgabe
|
|
class_name = "Einkommen > 50K" if class_to_extract == 1 else "Einkommen ≤ 50K"
|
|
if rule:
|
|
return f"WENN {rule} DANN {class_name}"
|
|
else:
|
|
return f"Keine Regel für {class_name} gefunden."
|
|
|
|
# Anwendung für die Extraktion einer Regel für hohes Einkommen (Klasse 1)
|
|
single_rule = extract_single_rule(surrogate_tree, X_train.columns.tolist(), class_to_extract=1)
|
|
print("Einzelne Regel aus dem Surrogate-Modell:")
|
|
print(single_rule) |