Aller au contenu

Open In Colab

4. Détection d'images similaires

Objectif

Prédire le diagnostic du patient

Donnés

Les données proviennent de Kaggle

Méthologie

SVM

Implémentation

Code source

Références

Webinar Workflow ML

Objectif: Prédire le diagnostic du patient

Donnés: Les données proviennent de Kaggle

Méthologie: SVM

Implémentation

  1. Importation des données
  2. Exploration
  3. Conversion des variables catégorielles en numérique
  4. Séparation du jeu de données
  5. Entraînement
  6. Sélection de modèle

Voici les principaux outils que nous utilisons pour l'implémentation - Python - Pandas - Scitkit-learn

Librairies

import pickle

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.svm import LinearSVC, SVC
from sklearn.ensemble import RandomForestClassifier, BaggingClassifier, GradientBoostingClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
sns.set()
sns.set_theme(style="white")
# Reproductibility
np.random.seed(42)

Importation des données

# Read data `Prostate_Cancer.csv` from my github. Original dataset come from Kaggle [https://www.kaggle.com/sajidsaifi/prostate-cancer]
try:
    data = pd.read_csv("Prostate_Cancer.csv")
except:
    data = pd.read_csv('https://raw.githubusercontent.com/joekakone/datasets/master/datasets/Prostate_Cancer.csv')
# Show the 10 first rows
data.head(10)
id diagnosis_result radius texture perimeter area smoothness compactness symmetry fractal_dimension
0 1 M 23 12 151 954 0.143 0.278 0.242 0.079
1 2 B 9 13 133 1326 0.143 0.079 0.181 0.057
2 3 M 21 27 130 1203 0.125 0.160 0.207 0.060
3 4 M 14 16 78 386 0.070 0.284 0.260 0.097
4 5 M 9 19 135 1297 0.141 0.133 0.181 0.059
5 6 B 25 25 83 477 0.128 0.170 0.209 0.076
6 7 M 16 26 120 1040 0.095 0.109 0.179 0.057
7 8 M 15 18 90 578 0.119 0.165 0.220 0.075
8 9 M 19 24 88 520 0.127 0.193 0.235 0.074
9 10 M 25 11 84 476 0.119 0.240 0.203 0.082
data.shape
(100, 10)

Le tableau contient 100 lignes et 10 colonnes. La première colonne id représente les identifiants des patient, elle ne nous sera pas utile dans notre travail, nous allons l'ignorer dans la suite. La colonnes diagnosis_result représnet quant à elle le résultat du diagnostic du patient, c'est cette valeur que nous allons prédire. Les autres colonnes décrivent l'état du patient, elles nous serviront çà prédire lle diagnostic du patient.

# Remove `id` column
data.drop(["id"], axis=1, inplace=True)

Nettoyage

data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 100 entries, 0 to 99
Data columns (total 9 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
 0   diagnosis_result   100 non-null    object 
 1   radius             100 non-null    int64  
 2   texture            100 non-null    int64  
 3   perimeter          100 non-null    int64  
 4   area               100 non-null    int64  
 5   smoothness         100 non-null    float64
 6   compactness        100 non-null    float64
 7   symmetry           100 non-null    float64
 8   fractal_dimension  100 non-null    float64
dtypes: float64(4), int64(4), object(1)
memory usage: 7.2+ KB

np.sum(data.isna())
diagnosis_result     0
radius               0
texture              0
perimeter            0
area                 0
smoothness           0
compactness          0
symmetry             0
fractal_dimension    0
dtype: int64

Dans notre tableau, il n'y a pas de données manquantes. Généralement ce n'est pas le cas et il faudra corriger cela.

Exploration

Distribution de la variable objectif

diagnosis_result = data.diagnosis_result.value_counts()
diagnosis_result
M    62
B    38
Name: diagnosis_result, dtype: int64
plt.figure(figsize=(8, 6))
diagnosis_result.plot(kind='bar')
plt.title("Distribution des résulats de diagnostic")
plt.show()

Distribution des variables explicatives

plt.figure(figsize=(8, 6))
sns.boxplot(data=data, orient="h")
plt.title("Distribution des variables explicatives")
plt.show()

Distribution des variables explicatives par la variable objectif

vars = data.groupby(by="diagnosis_result").mean()
vars
radius texture perimeter area smoothness compactness symmetry fractal_dimension
diagnosis_result
B 17.947368 17.763158 78.500000 474.342105 0.099053 0.086895 0.184053 0.064605
M 16.177419 18.516129 107.983871 842.951613 0.104984 0.151097 0.198758 0.064742

On constate une grande variation de perimeter et area en fonction de diagnosis_result

fig = plt.figure(figsize=(16, 6))

fig.add_subplot(1, 2, 1)
sns.distplot(x=data[data["diagnosis_result"]=="M"]["perimeter"])
sns.distplot(x=data[data["diagnosis_result"]=="B"]["perimeter"])
plt.title("perimeter")

fig.add_subplot(1, 2, 2)
sns.distplot(x=data[data["diagnosis_result"]=="M"]["area"])
sns.distplot(x=data[data["diagnosis_result"]=="B"]["area"])
plt.yticks([])
plt.title("area")

plt.show()
/usr/local/lib/python3.7/dist-packages/seaborn/distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
  warnings.warn(msg, FutureWarning)
/usr/local/lib/python3.7/dist-packages/seaborn/distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
  warnings.warn(msg, FutureWarning)
/usr/local/lib/python3.7/dist-packages/seaborn/distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
  warnings.warn(msg, FutureWarning)
/usr/local/lib/python3.7/dist-packages/seaborn/distributions.py:2557: FutureWarning: `distplot` is a deprecated function and will be removed in a future version. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or `histplot` (an axes-level function for histograms).
  warnings.warn(msg, FutureWarning)

Reagrdons ce qu'il en est des correlations éventuelles entre les variables explicatives.

plt.figure(figsize=(10, 8))
corr = data.drop(["diagnosis_result"], axis=1).corr()
sns.heatmap(corr, annot=True)
plt.title("Matrice de corrélation")
plt.show()
sns.pairplot(data=data, hue="diagnosis_result")
plt.show()