Wykorzystanie DeepShap do zrozumienia i udoskonalenia modelu napędzającego samochód autonomiczny
Przerażają mnie autonomiczne samochody. Wielkie kawałki metalu latają bez ludzi, którzy mogliby je powstrzymać, jeśli coś pójdzie nie tak. Aby zmniejszyć to ryzyko, nie wystarczy ocenić modele napędzające te bestie. Musimy także zrozumieć, w jaki sposób prognozują. Ma to na celu uniknięcie wszelkich przypadków Edge, które mogłyby spowodować nieprzewidziane wypadki.
OK, więc nasza aplikacja nie jest tak konsekwentna. Będziemy debugować model zastosowany do napędu minisamochodu (najgorsze, czego można się spodziewać, to stłuczona kostka). Mimo to metody IML mogą być przydatne. Zobaczymy, jak mogą jeszcze poprawić wydajność modelu.
W szczególności:
- Dostosuj ResNet-18 za pomocą PyTorch z danymi obrazu i ciągłą zmienną docelową
- Oceń model za pomocą MSE i wykresów punktowych
- Zinterpretuj model za pomocą DeepSHAP
- Popraw model poprzez lepsze gromadzenie danych
- Omów, w jaki sposób powiększenie obrazu mogłoby jeszcze bardziej ulepszyć model
Po drodze omówimy kilka kluczowych fragmentów kodu Pythona. Pełny projekt można także znaleźć na GitHub.
Jeśli dopiero zaczynasz korzystać z SHAP, obejrzyj filmponiżej. Jeśli chcesz więcej, sprawdź mój Kurs SHAP. Możesz uzyskać bezpłatny dostęp, jeśli zapiszesz się na mój Biuletyn :)
Pakiety Pythona
# Imports import numpy as np import pandas as pd import matplotlib.pyplot as plt import glob import random from PIL import Image import cv2 import torch import torchvision from torchvision import transforms from torch.utils.data import DataLoader import shap from sklearn.metrics import mean_squared_error
Zbiór danych
Projekt zaczynamy od zebrania danych tylko w jednym pomieszczeniu (będzie to nas prześladować). Jak wspomniano, używamy obrazów do zasilania zautomatyzowanego samochodu. Przykłady można znaleźć na Kaggle. Wszystkie te obrazy mają wymiary 224 x 224 pikseli.
Wyświetlamy jeden z nich za pomocą poniższego kodu. Zanotuj nazwę obrazu (wiersz 2). Pierwsze dwie liczby to współrzędne x i y w ramce 224 x 224. Na Rysunku 1 widać, że współrzędne te zostały wyświetlone w zielonym kółku (linia 8).
#Load example image name = "32_50_c78164b4-40d2-11ed-a47b-a46bb6070c92.jpg" x = int(name.split("_")[0]) y = int(name.split("_")[1]) img = Image.open("../data/room_1/" + name) img = np.array(img) cv2.circle(img, (x, y), 8, (0, 255, 0), 3) plt.imshow(img)
Te współrzędne są zmienną docelową. Model przewiduje je, wykorzystując obraz jako dane wejściowe. Ta prognoza jest następnie wykorzystywana do kierowania samochodem. W tym przypadku widać, że samochód zbliża się do skrętu w lewo. Idealnym kierunkiem jest podążanie w kierunku współrzędnych podanych przez zielone kółko.
Trenowanie modelu PyTorch
Chcę się skupić na SHAP, więc nie będziemy zagłębiać się zbytnio w kod modelowania. Jeśli masz jakieś pytania, możesz je zadać w komentarzach.
Zaczynamy od utworzenia klasy ImageDataset. Służy do ładowania naszych danych obrazu i zmiennych docelowych. Robi to za pomocą ścieżek do naszych obrazów. Należy zwrócić uwagę na sposób skalowania zmiennych docelowych — zarówno x, jak i y będą wynosić od -1 do 1 mocny>.
class ImageDataset(torch.utils.data.Dataset): def __init__(self, paths, transform): self.transform = transform self.paths = paths def __getitem__(self, idx): """Get image and target (x, y) coordinates""" # Read image path = self.paths[idx] image = cv2.imread(path, cv2.IMREAD_COLOR) image = Image.fromarray(image) # Transform image image = self.transform(image) # Get target target = self.get_target(path) target = torch.Tensor(target) return image, target def get_target(self,path): """Get the target (x, y) coordinates from path""" name = os.path.basename(path) items = name.split('_') x = items[0] y = items[1] # Scale between -1 and 1 x = 2.0 * (int(x)/ 224 - 0.5) # -1 left, +1 right y = 2.0 * (int(y) / 244 -0.5)# -1 top, +1 bottom return [x, y] def __len__(self): return len(self.paths)
W rzeczywistości, gdy model jest wdrażany, do kierowania samochodem wykorzystywane są tylko przewidywania x. Ze względu na skalowanie znak przewidywania x określi kierunek samochodu. Gdy x ‹ 0, samochód powinien skręcić w lewo. Podobnie, gdy x › 0 samochód powinien skręcić w prawo. Im większa wartość x, tym ostrzejszy skręt.
Klasę ImageDataset wykorzystujemy do tworzenia modułów ładujących dane szkoleniowe i walidacyjne. Odbywa się to poprzez losowy podział 80/20 wszystkich ścieżek obrazu z pokoju 1. Ostatecznie mamy 1217i 305 obrazy odpowiednio w zestawie szkoleniowym i walidacyjnym.
TRANSFORMS = transforms.Compose([ transforms.ColorJitter(0.2, 0.2, 0.2, 0.2), transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) paths = glob.glob('../data/room_1/*') # Shuffle the paths random.shuffle(paths) # Create a datasets for training and validation split = int(0.8 * len(paths)) train_data = ImageDataset(paths[:split], TRANSFORMS) valid_data = ImageDataset(paths[split:], TRANSFORMS) # Prepare data for Pytorch model train_loader = DataLoader(train_data, batch_size=32, shuffle=True) valid_loader = DataLoader(valid_data, batch_size=valid_data.__len__())
Zwróć uwagę na batch_size valid_loader. Używamy długości zbioru danych walidacyjnych (tj. 305). Dzięki temu możemy załadować wszystkie dane walidacyjne w jednej iteracji. Jeśli pracujesz z większymi zbiorami danych, może być konieczne użycie mniejszego rozmiaru partii.
Ładujemy wstępnie wytrenowany model ResNet18 (linia 5). Ustawiając model.fc, aktualizujemy ostatnią warstwę (linia 6). Jest to w pełni połączona warstwa od 512 węzłów do naszych 2 docelowych węzłów zmiennych. Będziemy używać optymalizatora Adama do dostrojenia tego modelu (wiersz 9).
output_dim = 2 # x, y device = torch.device('mps') # or 'cuda' if you have a GPU # RESNET 18 model = torchvision.models.resnet18(pretrained=True) model.fc = torch.nn.Linear(512, output_dim) model = model.to(device) optimizer = torch.optim.Adam(model.parameters())
Wytrenowałem model przy użyciu procesora graficznego (linia 2). Nadal będziesz mógł uruchomić kod na procesorze. Dostrajanie nie jest tak kosztowne obliczeniowo, jak szkolenie od zera!
Wreszcie mamy nasz kod szkoleniowy modelu. Trenujemy przez 10 epok, używając MSE jako naszej funkcji straty. Nasz ostateczny model to ten, który ma najniższy MSE w zestawie walidacyjnym.
name = "direction_model_1" # Change this to save a new model # Train the model min_loss = np.inf for epoch in range(10): model = model.train() for images, target in iter(train_loader): images = images.to(device) target = target.to(device) # Zero gradients of parameters optimizer.zero_grad() # Execute model to get outputs output = model(images) # Calculate loss loss = torch.nn.functional.mse_loss(output, target) # Run backpropogation to accumulate gradients loss.backward() # Update model parameters optimizer.step() # Calculate validation loss model = model.eval() images, target = next(iter(valid_loader)) images = images.to(device) target = target.to(device) output = model(images) valid_loss = torch.nn.functional.mse_loss(output, target) print("Epoch: {}, Validation Loss: {}".format(epoch, valid_loss.item())) if valid_loss < min_loss: print("Saving model") torch.save(model, '../models/{}.pth'.format(name)) min_loss = valid_loss
Metryki oceny
W tym momencie chcemy zrozumieć, jak radzi sobie nasz model. Przyglądamy się MSE i wykresom punktowym rzeczywistych i przewidywanych wartości x. Na razie ignorujemy y, ponieważ nie ma to wpływu na kierunek samochodu.
Zestaw szkoleniowy i walidacyjny
Rysunek 2 przedstawia te metryki w zestawie szkoleniowym i walidacyjnym. Ukośna czerwona linia daje doskonałe przewidywania. Istnieje podobna odmiana wokół tej prostej dla x ‹ 0 i x › 0. Innymi słowy, model jest w stanie przewidzieć skręty w lewo i w prawo z podobną dokładnością. Podobna wydajność zbioru uczącego i walidacyjnego wskazuje również, że model nie jest nadmiernie dopasowany.
Do stworzenia powyższego wykresu używamy funkcji model_evaluation. Należy pamiętać, że moduły ładujące dane należy utworzyć w taki sposób, aby ładowały wszystkie dane w pierwszej iteracji.
def model_evaluation(loaders,labels,save_path = None): """Evaluate direction models with mse and scatter plots loaders: list of data loaders labels: list of labels for plot title""" n = len(loaders) fig, axs = plt.subplots(1, n, figsize=(7*n, 6)) # Evalution metrics for i, loader in enumerate(loaders): # Load all data images, target = next(iter(loader)) images = images.to(device) target = target.to(device) output=model(images) # Get x predictions x_pred=output.detach().cpu().numpy()[:,0] x_target=target.cpu().numpy()[:,0] # Calculate MSE mse = mean_squared_error(x_target, x_pred) # Plot predcitons axs[i].scatter(x_target,x_pred) axs[i].plot([-1, 1], [-1, 1], color='r', linestyle='-', linewidth=2) axs[i].set_ylabel('Predicted x', size =15) axs[i].set_xlabel('Actual x', size =15) axs[i].set_title("{0} MSE: {1:.4f}".format(labels[i], mse),size = 18) if save_path != None: fig.savefig(save_path)
Możesz zobaczyć, co mamy na myśli, gdy używamy poniższej funkcji. Stworzyliśmy nowy train_loader, ustawiający wielkość partii na długość zbioru danych szkoleniowych. Ważne jest także załadowanie zapisanego modelu (linia 2). W przeciwnym razie skończysz na modelu wytrenowanym w ostatniej epoce.
# Load saved model model = torch.load('../models/direction_model_1.pth') model.eval() model.to(device) # Create new loader for all data train_loader = DataLoader(train_data, batch_size=train_data.__len__()) # Evaluate model on training and validation set loaders = [train_loader,valid_loader] labels = ["Train","Validation"] # Evaluate on training and validation set model_evaluation(loaders,labels)
Przeprowadzka do nowych lokalizacji
Wyniki wyglądają dobrze! Oczekiwaliśmy, że samochód będzie dobrze się sprawował i tak się stało. Tak było, dopóki nie przenieśliśmy go do nowej lokalizacji:
Zbieramy dane z nowych lokalizacji (pokój 2 i pokój 3). Przeprowadzając ocenę tych obrazów, można zauważyć, że nasz model nie działa równie dobrze. To jest dziwne! Samochód jedzie dokładnie po tym samym torze, więc dlaczego pomieszczenie ma znaczenie?
Debugowanie modelu za pomocą SHAP
Szukamy odpowiedzi w firmie SHAP. Można go wykorzystać do zrozumienia, które piksele są ważne dla danej prognozy. Zaczynamy od załadowania naszego zapisanego modelu (linia 2). SHAP nie został zaimplementowany dla GPU, dlatego ustawiamy urządzenie na CPU (linie 5–6).
# Load saved model model = torch.load('../models/direction_model_1.pth') # Use CPU device = torch.device('cpu') model = model.to(device)
Aby obliczyć wartości SHAP, musimy uzyskać obrazy tła. SHAP zintegruje te obrazy podczas obliczania wartości. Używamy batch_size 100 obrazów. To powinno dać nam rozsądne przybliżenia. Zwiększenie liczby obrazów poprawi przybliżenie, ale także wydłuży czas obliczeń.
#Load 100 images for background shap_loader = DataLoader(train_data, batch_size=100, shuffle=True) background, _ = next(iter(shap_loader)) background = background.to(device)
Tworzymy obiekt wyjaśniający, przekazując nasz model i obrazy tła do funkcji DeepExplainer. Ta funkcja skutecznie przybliża wartości SHAP dla sieci neuronowych. Alternatywnie możesz zastąpić ją funkcją GradientExplainer.
#Create SHAP explainer explainer = shap.DeepExplainer(model, background)
Ładujemy 2 przykładowe obrazy — skręt w prawo i w lewo (linia 2) i przekształcamy je (linia 6). Jest to ważne, ponieważ obrazy powinny mieć ten sam format, jaki został użyty do uczenia modelu. Następnie obliczamy wartości SHAP dla przewidywań dokonanych na podstawie tych obrazów (linia 10).
# Load test images of right and left turn paths = glob.glob('../data/room_1/*') test_images = [Image.open(paths[0]), Image.open(paths[3])] test_images = np.array(test_images) test_input = [TRANSFORMS(img) for img in test_images] test_input = torch.stack(test_input).to(device) # Get SHAP values shap_values = explainer.shap_values(test_input)
Na koniec możemy wyświetlić wartości SHAP za pomocą funkcji image_plot. Najpierw jednak musimy je zrestrukturyzować. Wartości SHAP są zwracane z wymiarami:
( #targets, #images, #channels, #width, #height)
Używamy funkcji transpozycji, więc mamy wymiary:
( #targets, #images, #width, #height, #channels)
Uwaga, przekazaliśmy także oryginalne obrazy do funkcji image_plot. Obrazy test_input wyglądałyby dziwnie z powodu przekształceń.
# Reshape shap values and images for plotting shap_numpy = list(np.array(shap_values).transpose(0,1,3,4,2)) test_numpy = np.array([np.array(img) for img in test_images]) shap.image_plot(shap_numpy, test_numpy,show=False)
Wynik możesz zobaczyć na Rysunku 4. Pierwsza kolumna zawiera oryginalne obrazy. Druga i trzecia kolumna to odpowiednio wartości SHAP dla predykcji x i y. Niebieskie piksele zmniejszyły przewidywanie. Dla porównania, czerwone piksele zwiększyły przewidywanie. Innymi słowy, w przypadku przewidywania x czerwone piksele spowodowały ostrzejszy skręt w prawo.
Teraz gdzieś dochodzimy. Ważnym rezultatem jest to, że model wykorzystuje piksele tła. Można to zobaczyć na rysunku 5, na którym przybliżamy prognozę x dla skrętu w prawo. Innymi słowy, tło jest ważne dla przewidywania. To wyjaśnia słabą wydajność! Kiedy przeprowadziliśmy się do nowego pokoju, tło się zmieniło, a nasze przewidywania stały się niewiarygodne.
Model jest nadmiernie dopasowany do danych z pomieszczenia nr 1. Na każdym obrazie występują te same obiekty i tło. W rezultacie model kojarzy je ze skrętami w lewo i w prawo. Nie mogliśmy tego zidentyfikować w naszej ocenie, ponieważ mamy to samo doświadczenie zarówno w obrazach szkoleniowych, jak i walidacyjnych.
Udoskonalanie modelu
Zależy nam na tym, aby nasz model dobrze spisał się w każdych warunkach. Aby to osiągnąć, spodziewalibyśmy się, że wykorzysta tylko piksele ze ścieżki. Omówmy zatem kilka sposobów zwiększenia wytrzymałości modelu.
Zbieranie nowych danych
Najlepszym rozwiązaniem jest po prostu zebranie większej ilości danych. Mamy już kilka z pokoju 2 i 3. Postępując w ten sam sposób, szkolimy nowy model, korzystając z danych ze wszystkich 3 pokoi. Patrząc na Rysunek 7, teraz działa on lepiej w przypadku obrazów z nowych pomieszczeń.
Mamy nadzieję, że trenując na danych z wielu pokoi, przełamiemy skojarzenia między obrotami a tłem. Na zakrętach w lewo i w prawo znajdują się teraz różne obiekty, ale tor pozostaje ten sam. Model powinien nauczyć się, że dla przewidywania ważny jest tor.
Możemy to potwierdzić, patrząc na wartości SHAP dla nowego modelu. Dotyczy to tych samych zakrętów, które widzieliśmy na Rysunku 4. Piksele tła przywiązują teraz mniejszą wagę. OK, nie jest idealnie, ale gdzieś zmierzamy.
Moglibyśmy nadal gromadzić dane. Im więcej lokalizacji zbierzemy, tym solidniejszy będzie nasz model. Gromadzenie danych może być jednak czasochłonne (i nudne!). Zamiast tego możemy pomyśleć o zwiększeniu ilości danych.
Rozszerzenia danych
Powiększanie danych ma miejsce wtedy, gdy systematycznie lub losowo zmieniamy obrazy za pomocą kodu. Dzięki temu możemy sztucznie wprowadzić szum i zwiększyć rozmiar naszego zbioru danych.
Na przykład możemy podwoić rozmiar naszego zbioru danych, odwracając obrazy wokół osi pionowej. Możemy to zrobić, ponieważ nasz tor jest symetryczny. Jak widać na rysunku 9, usunięcie również może być użyteczną metodą. Obejmuje to uwzględnianie obrazów, z których usunięto obiekty lub całe tło.
Budując solidny model, należy również wziąć pod uwagę takie czynniki, jak warunki oświetleniowe i jakość obrazu. Możemy je symulować za pomocą drgań kolorów lub dodając szum. Jeśli chcesz poznać wszystkie te metody, zapoznaj się z poniższym artykułem.
W powyższym artykule omawiamy również, dlaczego trudno jest stwierdzić, czy te ulepszenia zwiększyły wytrzymałość modelu. Moglibyśmy wdrożyć model w wielu środowiskach, ale jest to czasochłonne. Na szczęście SHAP może być używany jako alternatywa. Podobnie jak w przypadku gromadzenia danych, może nam to dać wgląd w to, jak ulepszenia zmieniły sposób, w jaki model formułuje prognozy.
Mam nadzieję, że podobał Ci się ten artykuł! Możesz mnie wesprzeć, zostając jednym z moich poleconych członków :)
| Twitter | YouTube | Newsletter — zapisz się, aby otrzymać DARMOWY dostęp do kursu Python SHAP
Zbiór danych
Obrazy JatRacer (CC0: domena publiczna) https://www.kaggle.com/datasets/conorsully1/jatracer-images
Bibliografia
SHAP, Przykład PyTorch Deep Wyjaśniacz MNIST https://shap.readthedocs.io/en/latest/example_notebooks/image_examples/image_classification/PyTorch%20Deep%20Explainer%20MNIST%20example.html