import matplotlib.pyplot as plt import numpy as np def plot_surrogate_accuracy_vs_depth_test_only(rf_model, X_train, X_test, max_depths=range(1, 16)): """ Visualisiert die Genauigkeit des Surrogate-Modells für verschiedene Baumtiefen, fokussiert nur auf die Testdaten. """ # Random Forest-Vorhersagen (nur einmal berechnen) rf_train_predictions = rf_model.predict(X_train) rf_test_predictions = rf_model.predict(X_test) # Ergebnisse für verschiedene Baumtiefen test_accuracies = [] for depth in max_depths: # Surrogate-Baum mit aktueller Tiefe trainieren surrogate_tree = DecisionTreeClassifier(max_depth=depth, random_state=42) surrogate_tree.fit(X_train, rf_train_predictions) # Vorhersagen surrogate_test_pred = surrogate_tree.predict(X_test) # Genauigkeit berechnen test_acc = np.mean(surrogate_test_pred == rf_test_predictions) test_accuracies.append(test_acc) # Visualisierung plt.figure(figsize=(10, 6)) plt.plot(max_depths, test_accuracies, 'o-', color='#ED7D31', linewidth=2) # Finde die beste Tiefe best_depth = max_depths[np.argmax(test_accuracies)] best_acc = max(test_accuracies) # Markiere den besten Punkt plt.scatter([best_depth], [best_acc], s=100, c='red', zorder=5) plt.annotate(f'Optimale Tiefe: {best_depth}\nGenauigkeit: {best_acc:.4f}', xy=(best_depth, best_acc), xytext=(best_depth+1, best_acc-0.05), arrowprops=dict(facecolor='black', shrink=0.05, width=1.5)) # Beschriftungen und Layout plt.grid(alpha=0.3) plt.title('Surrogate-Modell-Genauigkeit bei verschiedenen Baumtiefen', fontsize=14) plt.xlabel('Maximale Baumtiefe', fontsize=12) plt.ylabel('Genauigkeit auf Testdaten', fontsize=12) # Füge Werte über den Punkten hinzu for i, acc in enumerate(test_accuracies): plt.text(max_depths[i], acc + 0.01, f'{acc:.3f}', ha='center') # Y-Achse anpassen (je nach Daten) y_min = max(0, min(test_accuracies) - 0.05) plt.ylim(y_min, 1.05) # Verbesserte visuelle Elemente plt.fill_between(max_depths, test_accuracies, y_min, alpha=0.1, color='#ED7D31') plt.tight_layout() plt.savefig('output/surrogate_accuracy.png', dpi=300) plt.show() return best_depth, best_acc # Aufruf der Funktion best_depth, best_accuracy = plot_surrogate_accuracy_vs_depth_test_only(rf_model, X_train, X_test) print(f"Optimale Baumtiefe: {best_depth} mit einer Genauigkeit von {best_accuracy:.4f}")