import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.cluster import KMeans
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.metrics import adjusted_rand_score, silhouette_score

# Load the iris dataset
iris = load_iris()
X = iris.data
y = iris.target
feature_names = iris.feature_names
target_names = iris.target_names

# Convert to a DataFrame for easier manipulation
iris_df = pd.DataFrame(X, columns=feature_names)
iris_df['species'] = [target_names[i] for i in y]

# Scale the features
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# --- KNN Classification (Supervised) ---
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.3, random_state=42)
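
# Optional sketch (reference only): the split above uses features scaled on the
# full dataset, which leaks test-set statistics into the scaler. A Pipeline
# re-fits the scaler on the training portion only; the main comparison below
# keeps the pre-scaled arrays for simplicity.
from sklearn.pipeline import make_pipeline

Xr_train, Xr_test, yr_train, yr_test = train_test_split(X, y, test_size=0.3, random_state=42)
leak_free_knn = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=5))
leak_free_knn.fit(Xr_train, yr_train)
print(f"Leakage-free pipeline KNN accuracy: {leak_free_knn.score(Xr_test, yr_test):.4f}")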

# Train a K-Nearest Neighbors classifier
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train, y_train)

# Make predictions with KNN
y_pred_knn = knn.predict(X_test)
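
# Optional sketch: n_neighbors=5 is a common default; an alternative is to pick
# it by cross-validation on the training split (the test set stays untouched).
# The classifier above is left unchanged.
from sklearn.model_selection import cross_val_score

cv_acc_by_k = {
    k: cross_val_score(KNeighborsClassifier(n_neighbors=k), X_train, y_train, cv=5).mean()
    for k in (1, 3, 5, 7, 9, 11)
}
for k, acc in cv_acc_by_k.items():
    print(f"  k={k}: 5-fold CV accuracy = {acc:.4f}")
print(f"Best k on the training split: {max(cv_acc_by_k, key=cv_acc_by_k.get)}")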

# --- K-means Clustering (Unsupervised) ---
# Fit K-means with 3 clusters (same as number of species)
kmeans = KMeans(n_clusters=3, random_state=42, n_init=10)
cluster_labels = kmeans.fit_predict(X_scaled)
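
# Optional sketch: k=3 is chosen because we know there are three species. In a
# truly unsupervised setting, a quick silhouette sweep is one way to sanity-check
# the choice of k (printed for reference; the 3-cluster model above is kept).
for k in range(2, 7):
    trial_labels = KMeans(n_clusters=k, random_state=42, n_init=10).fit_predict(X_scaled)
    print(f"  k={k}: silhouette = {silhouette_score(X_scaled, trial_labels):.4f}")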

# --- Comparing Results ---
# K-means cluster indices are arbitrary, so before comparing against the true
# labels we map each cluster to the species that occurs most often inside it.
cluster_to_species = {}
for cluster_num in range(3):
    cluster_indices = np.where(cluster_labels == cluster_num)[0]
    species_in_cluster = y[cluster_indices]
    most_common_species = np.bincount(species_in_cluster).argmax()
    cluster_to_species[cluster_num] = most_common_species

# Map each cluster label to its most common species
mapped_clusters = np.array([cluster_to_species[label] for label in cluster_labels])
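
# Optional sketch: the majority-vote mapping above could, in principle, assign
# two clusters to the same species. A strictly one-to-one assignment can be
# obtained with the Hungarian algorithm (scipy's linear_sum_assignment) on the
# cluster/species contingency table; shown for comparison, not used below.
from scipy.optimize import linear_sum_assignment

contingency = np.zeros((3, 3), dtype=int)
for cluster_num, species_num in zip(cluster_labels, y):
    contingency[cluster_num, species_num] += 1
row_ind, col_ind = linear_sum_assignment(-contingency)  # negate to maximize matched counts
hungarian_mapping = {int(c): int(s) for c, s in zip(row_ind, col_ind)}
print(f"Hungarian cluster-to-species mapping: {hungarian_mapping}")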

# --- Evaluation Metrics ---
# 1. KNN accuracy
knn_accuracy = accuracy_score(y_test, y_pred_knn)

# 2. K-means "accuracy" after mapping clusters to species
#    (scored on all 150 samples, unlike KNN, which is scored on the held-out test set)
kmeans_accuracy = accuracy_score(y, mapped_clusters)

# 3. Adjusted Rand Index (measures similarity between clusterings)
ari_score = adjusted_rand_score(y, cluster_labels)

# 4. Silhouette score (measures how well-separated the clusters are)
silhouette = silhouette_score(X_scaled, cluster_labels)

# Print results
print("\n--- KNN Classification (Supervised) Results ---")
print(f"KNN Accuracy: {knn_accuracy:.4f}")
print("\nClassification Report:")
print(classification_report(y_test, y_pred_knn, target_names=target_names))

print("\n--- K-means Clustering (Unsupervised) Results ---")
print(f"K-means Accuracy (after mapping): {kmeans_accuracy:.4f}")
print(f"Adjusted Rand Index: {ari_score:.4f}")
print(f"Silhouette Score: {silhouette:.4f}")

print("\nCluster to Species Mapping:")
for cluster, species in cluster_to_species.items():
    print(f"Cluster {cluster} → {target_names[species]}")

# --- Visualization ---
# Create a PCA projection for visualization
from sklearn.decomposition import PCA
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X_scaled)
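
# Optional: report how much of the (scaled) feature variance the 2-D projection
# captures, as context for reading the scatter plots below.
print(f"\nPCA explained variance ratio: {pca.explained_variance_ratio_}")
print(f"Variance captured by 2 components: {pca.explained_variance_ratio_.sum():.4f}")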

# Plot both KNN and K-means results
plt.figure(figsize=(15, 6))

# Plot 1: True species
plt.subplot(1, 3, 1)
for i, species in enumerate(target_names):
    indices = y == i
    plt.scatter(X_pca[indices, 0], X_pca[indices, 1], label=species)
plt.title('True Species')
plt.xlabel('PCA Component 1')
plt.ylabel('PCA Component 2')
plt.legend()

# Plot 2: K-means clusters
plt.subplot(1, 3, 2)
for i in range(3):
    indices = cluster_labels == i
    plt.scatter(X_pca[indices, 0], X_pca[indices, 1], label=f'Cluster {i}')
plt.title('K-means Clusters')
plt.xlabel('PCA Component 1')
plt.ylabel('PCA Component 2')
plt.legend()

# Plot 3: K-means mapped to species 
plt.subplot(1, 3, 3)
for i, species in enumerate(target_names):
    indices = mapped_clusters == i
    plt.scatter(X_pca[indices, 0], X_pca[indices, 1], label=f'Mapped to {species}')
plt.title('K-means Mapped to Species')
plt.xlabel('PCA Component 1')
plt.ylabel('PCA Component 2')
plt.legend()

plt.tight_layout()
plt.savefig('knn_vs_kmeans.png')
plt.close()
print("\nVisualization saved as 'knn_vs_kmeans.png'")

# Save confusion matrix visualization
cm = confusion_matrix(y_test, y_pred_knn)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=target_names, yticklabels=target_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix for KNN Classification')
plt.tight_layout()
plt.savefig('confusion_matrix.png')
plt.close()

# Create and save pairplot (sns.pairplot creates its own figure,
# so no plt.figure() call is needed here)
sns.pairplot(iris_df, hue='species', markers=['o', 's', 'D'])
plt.suptitle('Pairplot of Iris Features by Species', y=1.02)
plt.savefig('pairplot.png')
plt.close()