Введение

Давайте научим компьютер читать рукописные цифры с помощью классификатора случайного леса в R.

Мы будем использовать набор данных рукописных цифр MNIST для обучения и тестирования нашей модели. Вы можете самостоятельно загрузить набор данных с этой страницы GitHub: Файлы MNIST в формате PNG

Подготовьте данные

Хитрость в том, чтобы сделать эту работу, чтобы закодировать изображения в формате, который может понять наш компьютер. Каждое изображение состоит из матрицы значений RGB, поэтому мы можем использовать функцию png::readPNG() из пакета {png} для считывания этих значений в матрицу R.

image <- png::readPNG("train/0/16585.png")

dim(image)

#> [1] 28 28

Каждое из изображений в наборе данных MNIST имеет размер 28 на 28 пикселей. Мы можем сгладить это по одному измерению, чтобы создать набор данных с 784 функциями.

# Flatten a Matrix
image_flat <- as.vector(
    png::readPNG("train/0/16585.png")
)

length(image_flat)

#> [1] 784

Если бы изображения были намного больше, мы могли бы сначала изменить их размер, используя функцию imager::resize() из пакета {imager}, чтобы предотвратить ошибку нехватки памяти.

Прыжки прямо в

Один из самых простых способов начать работу — создать новый проект в RStudio из системы управления версиями (т. е. GitHub).

После создания нашего проекта наш каталог должен выглядеть так.

list.files()

#> [1] "download-and-convert-to-png.py"
#> [2] "LICENSE"                  
#> [3] "README.md"                     
#> [4] "test"                          
#> [5] "test.csv"                      
#> [6] "train"                         
#> [7] "train.csv" 

В каталоге есть два файла: train.csv и test.csv с соответствующими путями к файлам и метками для каждого изображения. Мы можем прочитать это в нашем сеансе, используя readr::read_csv().

# Read the labels
training_images <- readr::read_csv(
    "train.csv",
    col_types = "cf"
)

testing_images <- readr::read_csv(
    "test.csv",
    col_types = "cf"
)

head(training_images)

#> # A tibble: 6 × 2
#>   filepath          label
#>   <chr>             <fct>
#> 1 train/0/16585.png 0    
#> 2 train/0/24537.png 0    
#> 3 train/0/25629.png 0    
#> 4 train/0/20751.png 0    
#> 5 train/0/34730.png 0    
#> 6 train/0/15926.png 0 

Обратите внимание, что я указал типы столбцов CSV-файла, используя параметр col_types = “cf” (символ и фактор), потому что пакет {parsnip} требует, чтобы метка классификации была фактором.

Создание обучающего набора данных

Следующим шагом является создание обучающего набора данных. Мы можем создать пустую матрицу размером с набор данных и просмотреть каждый из файлов изображений, чтобы прочитать их как матрицу и сгладить до вектора.

# Prepare Training Data
n <- length(training_images$filepath)
training_matrix <- matrix(
    nrow = n,
    ncol = 28 * 28
)

for (i in 1:n){
    training_matrix[i, ] <- 
        as.vector(
            png::readPNG(
                training_images$filepath[i]
            )
        )
}

Чтобы данные обучения работали с функцией подбора модели, мы преобразуем их во фрейм данных и связываем столбцы с метками.

training_data <- cbind(
    dplyr::select(training_images, label),
    as.data.frame(
        training_matrix
    )
)

Подгонка модели

Теперь мы можем подогнать модель к обучающим данным. Мы также можем сохранить результат в виде файла .rds, чтобы нам не приходилось повторять процесс обучения каждый раз, когда мы хотим делать прогнозы. Моему компьютеру требуется около 6 минут, чтобы подогнать модель.

# Model Fitting
rf_fit <- parsnip::fit(
    parsnip::rand_forest(
        mode = "classification"
    ),
    data = training_data,
    formula = label ~ .
)

# Save model
saveRDS(rf_fit, "mnist_model_fit.rds")

# rf_fit <- readRDS("mnist_model_fit.rds")

Создание набора данных для тестирования

Теперь мы можем использовать тот же процесс, что и с набором данных для обучения, для подготовки набора данных для тестирования.

# Prepare Test Data
n_test <- length(testing_images$filepath)
testing_matrix <- matrix(
    nrow = n,
    ncol = 28 * 28
)

for (i in 1:n_test){
    testing_matrix[i, ] <- 
        as.vector(
            png::readPNG(
                testing_images$filepath[i]
            )
        )
}

testing_data <- na.omit(
    cbind(
        dplyr::select(testing_images, label),
        as.data.frame(
            testing_matrix
        )
    )
)

Я добавил na.omit() вокруг обучающих данных, потому что некоторые значения отсутствовали, и это вызывало проблемы с функцией прогнозирования.

Делать предсказания

Давайте сделаем некоторые прогнозы и посмотрим, как мы это сделали! Мы можем использовать функцию предсказания из базы R и передать ей объект пастернака и набор тестовых данных. Функция metrics() из пакета {yardstick} упрощает расчет показателей производительности.

# Model Evaluation
predictions <- predict(
    rf_fit,
    testing_data
)

final_result <-
    dplyr::bind_cols(
        predictions,
        dplyr::select(
            testing_data,
            label
        )
    )

yardstick::metrics(
    final_result,
    truth = "label",
    estimate = ".pred_class"
)

#> # A tibble: 2 × 3
#>   .metric  .estimator .estimate
#>   <chr>    <chr>          <dbl>
#> 1 accuracy multiclass     0.97 
#> 2 kap      multiclass     0.967

Точность 97% процентов для нашей модели! Совсем неплохо. Так сколько же это было неправильно?

wrong_idx <- which(final_result$label != final_result$.pred_class)
right_idx <- which(final_result$label == final_result$.pred_class)

length(wrong_idx)

#> [1] 300

300 из 10 000 тестовых изображений были помечены неправильно. Давайте нарисуем несколько случайных из них, чтобы понять, почему.

random_right <- sample(right_idx, 3)
random_wrong <- sample(wrong_idx, 3)

# Plot the mistakes
ggplot2::ggplot(
    data = data.frame(
        x = seq(1, 10, length.out = 6),
        y = 1,
        images = testing_images$filepath[c(random_right, random_wrong)]
    ),
    ggplot2::aes(
        x, 
        y,
        image = images,
        label = paste(final_result$.pred_class[c(random_right, random_wrong)]) 
    )
) +
    ggimage::geom_image(
        size=.10
    ) +
    ggplot2::scale_y_continuous(
        limits = c(0, 2)
    ) +
    ggplot2::scale_x_continuous(
        limits = c(0, 11)
    ) +
    ggplot2::geom_text(
        size = 10,
        nudge_y = 0.25,
        color = c("green", "green", "green", "red", "red", "red")
    ) +
    ggplot2::theme_void()

Мы видим, что модель была близка, потому что числа, которые она пометила неправильно, трудно различить.

Заключение

Надеюсь, вам понравилась эта статья. Спасибо за чтение, и я желаю вам всего наилучшего! До следующего.

Код из статьи: mnist_classification.R