intial commit (forked from private repo)
This commit is contained in:
67
extracted_cells/cell25.py
Normal file
67
extracted_cells/cell25.py
Normal file
@ -0,0 +1,67 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
def plot_surrogate_accuracy_vs_depth_test_only(rf_model, X_train, X_test, max_depths=range(1, 16)):
|
||||
"""
|
||||
Visualisiert die Genauigkeit des Surrogate-Modells für verschiedene Baumtiefen,
|
||||
fokussiert nur auf die Testdaten.
|
||||
"""
|
||||
# Random Forest-Vorhersagen (nur einmal berechnen)
|
||||
rf_train_predictions = rf_model.predict(X_train)
|
||||
rf_test_predictions = rf_model.predict(X_test)
|
||||
|
||||
# Ergebnisse für verschiedene Baumtiefen
|
||||
test_accuracies = []
|
||||
|
||||
for depth in max_depths:
|
||||
# Surrogate-Baum mit aktueller Tiefe trainieren
|
||||
surrogate_tree = DecisionTreeClassifier(max_depth=depth, random_state=42)
|
||||
surrogate_tree.fit(X_train, rf_train_predictions)
|
||||
|
||||
# Vorhersagen
|
||||
surrogate_test_pred = surrogate_tree.predict(X_test)
|
||||
|
||||
# Genauigkeit berechnen
|
||||
test_acc = np.mean(surrogate_test_pred == rf_test_predictions)
|
||||
test_accuracies.append(test_acc)
|
||||
|
||||
# Visualisierung
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(max_depths, test_accuracies, 'o-', color='#ED7D31', linewidth=2)
|
||||
|
||||
# Finde die beste Tiefe
|
||||
best_depth = max_depths[np.argmax(test_accuracies)]
|
||||
best_acc = max(test_accuracies)
|
||||
|
||||
# Markiere den besten Punkt
|
||||
plt.scatter([best_depth], [best_acc], s=100, c='red', zorder=5)
|
||||
plt.annotate(f'Optimale Tiefe: {best_depth}\nGenauigkeit: {best_acc:.4f}',
|
||||
xy=(best_depth, best_acc), xytext=(best_depth+1, best_acc-0.05),
|
||||
arrowprops=dict(facecolor='black', shrink=0.05, width=1.5))
|
||||
|
||||
# Beschriftungen und Layout
|
||||
plt.grid(alpha=0.3)
|
||||
plt.title('Surrogate-Modell-Genauigkeit bei verschiedenen Baumtiefen', fontsize=14)
|
||||
plt.xlabel('Maximale Baumtiefe', fontsize=12)
|
||||
plt.ylabel('Genauigkeit auf Testdaten', fontsize=12)
|
||||
|
||||
# Füge Werte über den Punkten hinzu
|
||||
for i, acc in enumerate(test_accuracies):
|
||||
plt.text(max_depths[i], acc + 0.01, f'{acc:.3f}', ha='center')
|
||||
|
||||
# Y-Achse anpassen (je nach Daten)
|
||||
y_min = max(0, min(test_accuracies) - 0.05)
|
||||
plt.ylim(y_min, 1.05)
|
||||
|
||||
# Verbesserte visuelle Elemente
|
||||
plt.fill_between(max_depths, test_accuracies, y_min, alpha=0.1, color='#ED7D31')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig('output/surrogate_accuracy.png', dpi=300)
|
||||
plt.show()
|
||||
|
||||
return best_depth, best_acc
|
||||
|
||||
# Aufruf der Funktion
|
||||
best_depth, best_accuracy = plot_surrogate_accuracy_vs_depth_test_only(rf_model, X_train, X_test)
|
||||
print(f"Optimale Baumtiefe: {best_depth} mit einer Genauigkeit von {best_accuracy:.4f}")
|
||||
Reference in New Issue
Block a user