67 lines
2.6 KiB
Python
67 lines
2.6 KiB
Python
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
|
|
def plot_surrogate_accuracy_vs_depth_test_only(rf_model, X_train, X_test, max_depths=range(1, 16)):
|
|
"""
|
|
Visualisiert die Genauigkeit des Surrogate-Modells für verschiedene Baumtiefen,
|
|
fokussiert nur auf die Testdaten.
|
|
"""
|
|
# Random Forest-Vorhersagen (nur einmal berechnen)
|
|
rf_train_predictions = rf_model.predict(X_train)
|
|
rf_test_predictions = rf_model.predict(X_test)
|
|
|
|
# Ergebnisse für verschiedene Baumtiefen
|
|
test_accuracies = []
|
|
|
|
for depth in max_depths:
|
|
# Surrogate-Baum mit aktueller Tiefe trainieren
|
|
surrogate_tree = DecisionTreeClassifier(max_depth=depth, random_state=42)
|
|
surrogate_tree.fit(X_train, rf_train_predictions)
|
|
|
|
# Vorhersagen
|
|
surrogate_test_pred = surrogate_tree.predict(X_test)
|
|
|
|
# Genauigkeit berechnen
|
|
test_acc = np.mean(surrogate_test_pred == rf_test_predictions)
|
|
test_accuracies.append(test_acc)
|
|
|
|
# Visualisierung
|
|
plt.figure(figsize=(10, 6))
|
|
plt.plot(max_depths, test_accuracies, 'o-', color='#ED7D31', linewidth=2)
|
|
|
|
# Finde die beste Tiefe
|
|
best_depth = max_depths[np.argmax(test_accuracies)]
|
|
best_acc = max(test_accuracies)
|
|
|
|
# Markiere den besten Punkt
|
|
plt.scatter([best_depth], [best_acc], s=100, c='red', zorder=5)
|
|
plt.annotate(f'Optimale Tiefe: {best_depth}\nGenauigkeit: {best_acc:.4f}',
|
|
xy=(best_depth, best_acc), xytext=(best_depth+1, best_acc-0.05),
|
|
arrowprops=dict(facecolor='black', shrink=0.05, width=1.5))
|
|
|
|
# Beschriftungen und Layout
|
|
plt.grid(alpha=0.3)
|
|
plt.title('Surrogate-Modell-Genauigkeit bei verschiedenen Baumtiefen', fontsize=14)
|
|
plt.xlabel('Maximale Baumtiefe', fontsize=12)
|
|
plt.ylabel('Genauigkeit auf Testdaten', fontsize=12)
|
|
|
|
# Füge Werte über den Punkten hinzu
|
|
for i, acc in enumerate(test_accuracies):
|
|
plt.text(max_depths[i], acc + 0.01, f'{acc:.3f}', ha='center')
|
|
|
|
# Y-Achse anpassen (je nach Daten)
|
|
y_min = max(0, min(test_accuracies) - 0.05)
|
|
plt.ylim(y_min, 1.05)
|
|
|
|
# Verbesserte visuelle Elemente
|
|
plt.fill_between(max_depths, test_accuracies, y_min, alpha=0.1, color='#ED7D31')
|
|
|
|
plt.tight_layout()
|
|
plt.savefig('output/surrogate_accuracy.png', dpi=300)
|
|
plt.show()
|
|
|
|
return best_depth, best_acc
|
|
|
|
# Aufruf der Funktion
|
|
best_depth, best_accuracy = plot_surrogate_accuracy_vs_depth_test_only(rf_model, X_train, X_test)
|
|
print(f"Optimale Baumtiefe: {best_depth} mit einer Genauigkeit von {best_accuracy:.4f}") |