40 lines
1.7 KiB
Python
40 lines
1.7 KiB
Python
import numpy as np
|
|
import pandas as pd
|
|
import matplotlib.pyplot as plt
|
|
from sklearn.linear_model import LinearRegression
|
|
from sklearn.metrics.pairwise import euclidean_distances
|
|
|
|
# Anzahl der zu erzeugenden Perturbationen
|
|
num_samples = 500
|
|
|
|
# Die Originalinstanz, die wir erklären wollen
|
|
original_instance = instance_df.iloc[0].copy()
|
|
|
|
# Erstelle perturbierte Instanzen durch zufällige Variation aller Features
|
|
perturbed_instances = pd.DataFrame(
|
|
np.random.normal(loc=original_instance, scale=1.0, size=(num_samples, len(original_instance))),
|
|
columns=original_instance.index
|
|
)
|
|
|
|
# Vorhersagen für die perturbierten Instanzen mit dem Random-Forest-Modell
|
|
perturbed_instances["prediction"] = best_rf_model.predict_proba(perturbed_instances)[:, 1]
|
|
|
|
# Berechnung der Gewichte nach Distanz zur Originalinstanz
|
|
distances = euclidean_distances(perturbed_instances.drop(columns=["prediction"]), [original_instance])
|
|
kernel_width = np.sqrt(len(original_instance)) # Kernel-Bandbreite
|
|
weights = np.exp(- (distances ** 2) / (2 * (kernel_width ** 2)))
|
|
|
|
# Gewichtete lokale lineare Regression zum Erklären der Vorhersage
|
|
lin_reg = LinearRegression()
|
|
lin_reg.fit(perturbed_instances.drop(columns=["prediction"]), perturbed_instances["prediction"], sample_weight=weights.flatten())
|
|
|
|
# Anzeige der Feature-Wichtigkeiten
|
|
feature_importances = pd.Series(lin_reg.coef_, index=original_instance.index).sort_values(key=abs, ascending=False)
|
|
|
|
# Visualisierung der wichtigsten Features
|
|
plt.figure(figsize=(8, 5))
|
|
feature_importances[:10].plot(kind="barh", color="skyblue")
|
|
plt.xlabel("Einfluss auf die Vorhersage")
|
|
plt.title("Erklärungsmodell (Nachbildung von LIME)")
|
|
plt.gca().invert_yaxis()
|
|
plt.show() |