Files

67 lines
2.6 KiB
Python

import matplotlib.pyplot as plt
import numpy as np
def plot_surrogate_accuracy_vs_depth_test_only(rf_model, X_train, X_test, max_depths=range(1, 16)):
"""
Visualisiert die Genauigkeit des Surrogate-Modells für verschiedene Baumtiefen,
fokussiert nur auf die Testdaten.
"""
# Random Forest-Vorhersagen (nur einmal berechnen)
rf_train_predictions = rf_model.predict(X_train)
rf_test_predictions = rf_model.predict(X_test)
# Ergebnisse für verschiedene Baumtiefen
test_accuracies = []
for depth in max_depths:
# Surrogate-Baum mit aktueller Tiefe trainieren
surrogate_tree = DecisionTreeClassifier(max_depth=depth, random_state=42)
surrogate_tree.fit(X_train, rf_train_predictions)
# Vorhersagen
surrogate_test_pred = surrogate_tree.predict(X_test)
# Genauigkeit berechnen
test_acc = np.mean(surrogate_test_pred == rf_test_predictions)
test_accuracies.append(test_acc)
# Visualisierung
plt.figure(figsize=(10, 6))
plt.plot(max_depths, test_accuracies, 'o-', color='#ED7D31', linewidth=2)
# Finde die beste Tiefe
best_depth = max_depths[np.argmax(test_accuracies)]
best_acc = max(test_accuracies)
# Markiere den besten Punkt
plt.scatter([best_depth], [best_acc], s=100, c='red', zorder=5)
plt.annotate(f'Optimale Tiefe: {best_depth}\nGenauigkeit: {best_acc:.4f}',
xy=(best_depth, best_acc), xytext=(best_depth+1, best_acc-0.05),
arrowprops=dict(facecolor='black', shrink=0.05, width=1.5))
# Beschriftungen und Layout
plt.grid(alpha=0.3)
plt.title('Surrogate-Modell-Genauigkeit bei verschiedenen Baumtiefen', fontsize=14)
plt.xlabel('Maximale Baumtiefe', fontsize=12)
plt.ylabel('Genauigkeit auf Testdaten', fontsize=12)
# Füge Werte über den Punkten hinzu
for i, acc in enumerate(test_accuracies):
plt.text(max_depths[i], acc + 0.01, f'{acc:.3f}', ha='center')
# Y-Achse anpassen (je nach Daten)
y_min = max(0, min(test_accuracies) - 0.05)
plt.ylim(y_min, 1.05)
# Verbesserte visuelle Elemente
plt.fill_between(max_depths, test_accuracies, y_min, alpha=0.1, color='#ED7D31')
plt.tight_layout()
plt.savefig('output/surrogate_accuracy.png', dpi=300)
plt.show()
return best_depth, best_acc
# Aufruf der Funktion
best_depth, best_accuracy = plot_surrogate_accuracy_vs_depth_test_only(rf_model, X_train, X_test)
print(f"Optimale Baumtiefe: {best_depth} mit einer Genauigkeit von {best_accuracy:.4f}")