Reading in Faces

You can skip this section if you’re not interested in how the images are loaded.

# Packages used below: png (readPNG), grid (grid.raster) and glmnet
# (glmnet, cv.glmnet). library() stops immediately if a package is missing,
# whereas require() only returns FALSE and lets the script fail later.
library(png)
library(grid)
library(glmnet)

# Load face images for several actors into a single data frame.
#
# For every actor name and every image number i, reads "<path>/<name><i>.png",
# converts it to grayscale, and stores it as one row of im_size * im_size
# pixel columns (X1, X2, ...), flattened column by column. A final `name`
# column records the actor. Rows are grouped by actor (in the order of
# `names`) and, within an actor, follow the order of `idx`.
#
# Args:
#   path:    directory holding the PNG files.
#   names:   character vector of actor file-name prefixes.
#   idx:     image numbers to load for every actor.
#   im_size: expected side length, in pixels, of every (square) image.
# Returns:
#   A data.frame with length(names) * length(idx) rows and
#   im_size * im_size + 1 columns.
load_actors <- function(path, names, idx, im_size) {
  n_pix <- im_size * im_size
  # Fill a preallocated numeric matrix: assigning rows of a data.frame with
  # thousands of columns one at a time is slow.
  pixels <- matrix(NA_real_, nrow = length(names) * length(idx), ncol = n_pix)
  k <- 0L
  for (name in names) {
    for (i in idx) {
      file <- sprintf("%s/%s%d.png", path, name, i)
      im <- readPNG(file)
      gray <- if (length(dim(im)) == 2L) {
        im                                # already single-channel grayscale
      } else if (dim(im)[3] >= 3L) {
        rowMeans(im[, , 1:3], dims = 2)   # average R, G, B; ignore any alpha
      } else {
        im[, , 1]                         # gray + alpha: keep the gray channel
      }
      if (length(gray) != n_pix) {
        stop(sprintf("%s: expected %s x %s pixels, got %d values",
                     file, im_size, im_size, length(gray)), call. = FALSE)
      }
      k <- k + 1L
      pixels[k, ] <- as.vector(gray)      # column-major flattening
    }
  }
  actors <- data.frame(pixels)            # columns are named X1, X2, ...
  # Same order as the nested loops above: actor-major, then by idx.
  actors$name <- rep(names, each = length(idx))
  actors
}

Now, let’s create the training and test sets: images 1–80 of each actor for training and images 81–108 for testing.

# The two actors to discriminate between; these are the file-name prefixes
# of the cropped images.
names <- c("DanielRadcliffe", "AngieHarmon")

# Directory of cropped images. Set the FACES_DIR environment variable to use
# a local copy; otherwise fall back to the original author's path.
path <- Sys.getenv(
  "FACES_DIR",
  unset = "/media/guerzhoy/Windows/STA303/lectures/faces/cropped128"
)
im_size <- 128

# Images 1-80 of each actor for training, 81-108 for testing.
train_set <- load_actors(path, names, 1:80, im_size)
test_set <- load_actors(path, names, 81:108, im_size)

Here are the first 100 dimensions from the first row of the test set.

# First 100 pixel columns of the first test image.
test_set[1, seq_len(100)]
##          X1        X2        X3       X4        X5        X6        X7
## 1 0.3581699 0.3620915 0.3686275 0.379085 0.4039216 0.4339869 0.4718954
##          X8        X9       X10       X11       X12       X13       X14
## 1 0.5163399 0.5620915 0.5882353 0.5947712 0.5908497 0.5856209 0.5777778
##         X15       X16       X17      X18       X19       X20       X21
## 1 0.5738562 0.5620915 0.5464052 0.530719 0.5228758 0.5163399 0.5111111
##         X22       X23       X24       X25       X26       X27       X28
## 1 0.5084967 0.5058824 0.4980392 0.4928105 0.4849673 0.4732026 0.4575163
##         X29       X30       X31       X32       X33       X34       X35
## 1 0.4457516 0.4444444 0.4457516 0.4392157 0.4339869 0.4352941 0.4300654
##         X36       X37       X38       X39      X40      X41       X42
## 1 0.4222222 0.4196078 0.4156863 0.4013072 0.379085 0.372549 0.3869281
##         X43       X44       X45       X46       X47      X48       X49
## 1 0.4065359 0.4222222 0.4287582 0.4326797 0.4326797 0.427451 0.4405229
##         X50       X51       X52       X53       X54       X55       X56
## 1 0.4366013 0.4352941 0.4366013 0.4392157 0.4392157 0.4326797 0.4339869
##         X57       X58       X59       X60      X61       X62       X63
## 1 0.4287582 0.4196078 0.4196078 0.4248366 0.427451 0.4261438 0.4222222
##         X64       X65       X66       X67       X68       X69       X70
## 1 0.4222222 0.4196078 0.4117647 0.4039216 0.4013072 0.3921569 0.3843137
##        X71       X72       X73       X74      X75       X76       X77
## 1 0.379085 0.3816993 0.3764706 0.3738562 0.372549 0.3738562 0.3751634
##        X78       X79       X80       X81       X82       X83       X84
## 1 0.372549 0.3607843 0.3607843 0.3633987 0.3673203 0.3633987 0.3477124
##         X85      X86       X87       X88      X89       X90       X91
## 1 0.3372549 0.330719 0.3045752 0.2797386 0.248366 0.2104575 0.1660131
##         X92       X93       X94       X95       X96       X97      X98
## 1 0.1320261 0.1189542 0.1254902 0.1542484 0.1803922 0.1660131 0.124183
##         X99       X100
## 1 0.0875817 0.07189542

There are \(128\times 128 + 1 = 16385\) columns in total: one per pixel, plus a final column holding the name. The test set has \(2 \times 28 = 56\) rows (28 images per actor):

# Number of rows (images) and columns (pixels + name) in the test set.
c(nrow(test_set), ncol(test_set))
## [1]    56 16385

Here are the first ten names:

# Labels of the first ten test images.
test_set[["name"]][seq_len(10)]
##  [1] "DanielRadcliffe" "DanielRadcliffe" "DanielRadcliffe"
##  [4] "DanielRadcliffe" "DanielRadcliffe" "DanielRadcliffe"
##  [7] "DanielRadcliffe" "DanielRadcliffe" "DanielRadcliffe"
## [10] "DanielRadcliffe"

Let’s display the 4th row of the training set in a more sensible way, as an image:

# Rebuild the 4th training image as an im_size x im_size matrix (pixels were
# stored column by column) and draw it.
im <- matrix(unlist(train_set[4, seq_len(im_size * im_size)], use.names = FALSE),
             nrow = im_size, ncol = im_size)
grid.raster(im)

Now, let’s fit the ridge logistic regression model. In glmnet, alpha=0 selects the ridge (\(L_2\)) penalty, and family="binomial" makes the model a logistic regression. We also need to choose the penalty parameter \(\lambda\), which we do by cross-validation with cv.glmnet.

# Ridge-penalised (alpha = 0) logistic regression on the raw pixels.
# factor() sorts levels alphabetically and glmnet models the last level, so
# positive log-odds mean "DanielRadcliffe".
x_train <- as.matrix(train_set[, 1:(im_size * im_size)])
y_train <- factor(train_set[, (im_size * im_size) + 1])
fit <- glmnet(x_train, y = y_train, alpha = 0, family = "binomial")
# cv.glmnet assigns observations to folds at random; fix the seed so the
# selected lambda is reproducible from run to run.
set.seed(1)
fit_cv <- cv.glmnet(x_train, y = y_train, alpha = 0, family = "binomial")

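Looking at fit_cv, the \(\lambda\) with the lowest cross-validated deviance (lambda.min) is \(2.99\), at index 98. This is the smallest \(\lambda\) in the cross-validation path, so the optimum may lie at an even smaller \(\lambda\). The more conservative choice, lambda.1se, is \(5.47\), at index 85: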

# Show the full cross-validation result (lambda path, CV deviance, etc.).
print(fit_cv)
## $lambda
##  [1] 272.335972 259.957882 248.142395 236.863940 226.098108 215.821600
##  [7] 206.012176 196.648604 187.710622 179.178884 171.034927 163.261126
## [13] 155.840656 148.757457 141.996201 135.542255 129.381650 123.501055
## [19] 117.887741 112.529561 107.414919 102.532745  97.872474  93.424019
## [25]  89.177753  85.124486  81.255447  77.562262  74.036937  70.671844
## [31]  67.459700  64.393552  61.466766  58.673007  56.006228  53.460658
## [37]  51.030789  48.711361  46.497354  44.383978  42.366657  40.441027
## [43]  38.602920  36.848358  35.173543  33.574852  32.048823  30.592154
## [49]  29.201694  27.874432  26.607496  25.398144  24.243759  23.141843
## [55]  22.090010  21.085985  20.127595  19.212764  18.339515  17.505955
## [61]  16.710283  15.950775  15.225787  14.533752  13.873170  13.242613
## [67]  12.640716  12.066176  11.517750  10.994250  10.494545  10.017551
## [73]   9.562238   9.127620   8.712755   8.316747   7.938738   7.577910
## [79]   7.233482   6.904710   6.590880   6.291314   6.005364   5.732411
## [85]   5.471864   5.223160   4.985759   4.759148   4.542838   4.336359
## [91]   4.139265   3.951129   3.771544   3.600121   3.436490   3.280296
## [97]   3.131202   2.988884
## 
## $cvm
##  [1] 1.1362747 1.0134736 0.9862460 0.9779680 0.9698158 0.9616737 0.9534620
##  [8] 0.9451703 0.9368165 0.9284109 0.9199558 0.9114515 0.9028852 0.8942698
## [15] 0.8856053 0.8768943 0.8681389 0.8594464 0.8506675 0.8417761 0.8328852
## [22] 0.8239305 0.8149813 0.8059878 0.7970068 0.7879677 0.7789444 0.7699200
## [29] 0.7608920 0.7517505 0.7427519 0.7337202 0.7247357 0.7157790 0.7072005
## [36] 0.6985062 0.6896145 0.6808056 0.6719419 0.6632907 0.6546387 0.6459423
## [43] 0.6374648 0.6290443 0.6205396 0.6123019 0.6045001 0.5965120 0.5882322
## [50] 0.5803162 0.5726054 0.5647140 0.5569255 0.5494122 0.5421369 0.5345625
## [57] 0.5270431 0.5201016 0.5131998 0.5060377 0.4991320 0.4924170 0.4859269
## [64] 0.4796461 0.4730550 0.4665785 0.4605271 0.4546780 0.4487385 0.4428954
## [71] 0.4372073 0.4314809 0.4261161 0.4210363 0.4156827 0.4104346 0.4053737
## [78] 0.4003856 0.3958008 0.3915061 0.3868007 0.3820129 0.3775684 0.3734223
## [85] 0.3692550 0.3652459 0.3614991 0.3577789 0.3539451 0.3500542 0.3462760
## [92] 0.3428080 0.3395871 0.3362407 0.3327153 0.3296964 0.3270010 0.3241913
## 
## $cvsd
##  [1] 0.06007727 0.03756447 0.03466333 0.03514948 0.03559074 0.03600186
##  [7] 0.03640083 0.03679110 0.03716927 0.03753267 0.03788075 0.03821373
## [13] 0.03853251 0.03883628 0.03912491 0.03939824 0.03965622 0.03986304
## [19] 0.04007074 0.04028915 0.04048108 0.04066971 0.04083549 0.04099609
## [25] 0.04113565 0.04126652 0.04137694 0.04148642 0.04156929 0.04159316
## [31] 0.04166427 0.04172006 0.04176255 0.04176232 0.04168802 0.04170151
## [37] 0.04172392 0.04171182 0.04174161 0.04170967 0.04169016 0.04170106
## [43] 0.04165909 0.04162573 0.04162714 0.04158533 0.04161793 0.04166129
## [49] 0.04160275 0.04154297 0.04152324 0.04152671 0.04150366 0.04145178
## [55] 0.04142897 0.04143549 0.04142694 0.04141084 0.04138794 0.04140106
## [61] 0.04143635 0.04145183 0.04143187 0.04144642 0.04150340 0.04154502
## [67] 0.04159673 0.04162779 0.04165891 0.04173811 0.04181852 0.04188959
## [73] 0.04196407 0.04203908 0.04212865 0.04223719 0.04235769 0.04246354
## [79] 0.04256190 0.04268127 0.04282081 0.04295021 0.04308154 0.04322840
## [85] 0.04338625 0.04353202 0.04368179 0.04383912 0.04401370 0.04418281
## [91] 0.04435027 0.04452952 0.04471044 0.04489897 0.04511012 0.04533384
## [97] 0.04555850 0.04575308
## 
## $cvup
##  [1] 1.1963520 1.0510381 1.0209093 1.0131175 1.0054065 0.9976755 0.9898629
##  [8] 0.9819614 0.9739858 0.9659436 0.9578365 0.9496653 0.9414177 0.9331061
## [15] 0.9247303 0.9162925 0.9077951 0.8993095 0.8907383 0.8820652 0.8733663
## [22] 0.8646002 0.8558168 0.8469839 0.8381425 0.8292343 0.8203214 0.8114065
## [29] 0.8024613 0.7933437 0.7844161 0.7754403 0.7664983 0.7575413 0.7488885
## [36] 0.7402077 0.7313385 0.7225174 0.7136835 0.7050004 0.6963288 0.6876433
## [43] 0.6791239 0.6706700 0.6621667 0.6538872 0.6461180 0.6381733 0.6298350
## [50] 0.6218592 0.6141286 0.6062407 0.5984291 0.5908640 0.5835659 0.5759980
## [57] 0.5684701 0.5615125 0.5545877 0.5474388 0.5405683 0.5338688 0.5273588
## [64] 0.5210925 0.5145584 0.5081235 0.5021238 0.4963058 0.4903974 0.4846335
## [71] 0.4790258 0.4733705 0.4680802 0.4630754 0.4578114 0.4526718 0.4477314
## [78] 0.4428491 0.4383627 0.4341874 0.4296215 0.4249631 0.4206500 0.4166507
## [85] 0.4126412 0.4087779 0.4051808 0.4016180 0.3979588 0.3942370 0.3906263
## [92] 0.3873375 0.3842975 0.3811397 0.3778254 0.3750303 0.3725595 0.3699444
## 
## $cvlo
##  [1] 1.0761975 0.9759091 0.9515827 0.9428185 0.9342251 0.9256718 0.9170612
##  [8] 0.9083792 0.8996473 0.8908782 0.8820750 0.8732378 0.8643527 0.8554335
## [15] 0.8464804 0.8374961 0.8284827 0.8195834 0.8105968 0.8014869 0.7924041
## [22] 0.7832608 0.7741458 0.7649918 0.7558712 0.7467012 0.7375675 0.7284336
## [29] 0.7193227 0.7101573 0.7010876 0.6920002 0.6829732 0.6740167 0.6655125
## [36] 0.6568047 0.6478906 0.6390938 0.6302003 0.6215810 0.6129485 0.6042412
## [43] 0.5958057 0.5874185 0.5789124 0.5707165 0.5628822 0.5548507 0.5466295
## [50] 0.5387732 0.5310822 0.5231873 0.5154218 0.5079604 0.5007079 0.4931270
## [57] 0.4856162 0.4786908 0.4718118 0.4646367 0.4576956 0.4509651 0.4444951
## [64] 0.4381997 0.4315516 0.4250335 0.4189303 0.4130502 0.4070796 0.4011573
## [71] 0.3953888 0.3895913 0.3841521 0.3789972 0.3735541 0.3681974 0.3630161
## [78] 0.3579220 0.3532389 0.3488249 0.3439799 0.3390627 0.3344869 0.3301939
## [85] 0.3258687 0.3217139 0.3178173 0.3139398 0.3099314 0.3058713 0.3019258
## [92] 0.2982784 0.2948766 0.2913417 0.2876052 0.2843626 0.2814425 0.2784383
## 
## $nzero
##    s0    s1    s2    s3    s4    s5    s6    s7    s8    s9   s10   s11 
## 16384 16384 16384 16384 16384 16384 16384 16384 16384 16384 16384 16384 
##   s12   s13   s14   s15   s16   s17   s18   s19   s20   s21   s22   s23 
## 16384 16384 16384 16384 16384 16384 16384 16384 16384 16384 16384 16384 
##   s24   s25   s26   s27   s28   s29   s30   s31   s32   s33   s34   s35 
## 16384 16384 16384 16384 16384 16384 16384 16384 16384 16384 16384 16384 
##   s36   s37   s38   s39   s40   s41   s42   s43   s44   s45   s46   s47 
## 16384 16384 16384 16384 16384 16384 16384 16384 16384 16384 16384 16384 
##   s48   s49   s50   s51   s52   s53   s54   s55   s56   s57   s58   s59 
## 16384 16384 16384 16384 16384 16384 16384 16384 16384 16384 16384 16384 
##   s60   s61   s62   s63   s64   s65   s66   s67   s68   s69   s70   s71 
## 16384 16384 16384 16384 16384 16384 16384 16384 16384 16384 16384 16384 
##   s72   s73   s74   s75   s76   s77   s78   s79   s80   s81   s82   s83 
## 16384 16384 16384 16384 16384 16384 16384 16384 16384 16384 16384 16384 
##   s84   s85   s86   s87   s88   s89   s90   s91   s92   s93   s94   s95 
## 16384 16384 16384 16384 16384 16384 16384 16384 16384 16384 16384 16384 
##   s96   s97 
## 16384 16384 
## 
## $name
##            deviance 
## "Binomial Deviance" 
## 
## $glmnet.fit
## 
## Call:  glmnet(x = as.matrix(train_set[, 1:(im_size * im_size)]), y = factor(train_set[,      (im_size * im_size) + 1]), alpha = 0, family = "binomial") 
## 
##           Df       %Dev  Lambda
##   [1,] 16384 -3.203e-15 272.300
##   [2,] 16384  3.470e-01 260.000
##   [3,] 16384  3.549e-01 248.100
##   [4,] 16384  3.626e-01 236.900
##   [5,] 16384  3.700e-01 226.100
##   [6,] 16384  3.773e-01 215.800
##   [7,] 16384  3.847e-01 206.000
##   [8,] 16384  3.922e-01 196.600
##   [9,] 16384  3.997e-01 187.700
##  [10,] 16384  4.074e-01 179.200
##  [11,] 16384  4.150e-01 171.000
##  [12,] 16384  4.227e-01 163.300
##  [13,] 16384  4.305e-01 155.800
##  [14,] 16384  4.383e-01 148.800
##  [15,] 16384  4.462e-01 142.000
##  [16,] 16384  4.541e-01 135.500
##  [17,] 16384  4.621e-01 129.400
##  [18,] 16384  4.700e-01 123.500
##  [19,] 16384  4.781e-01 117.900
##  [20,] 16384  4.861e-01 112.500
##  [21,] 16384  4.942e-01 107.400
##  [22,] 16384  5.023e-01 102.500
##  [23,] 16384  5.104e-01  97.870
##  [24,] 16384  5.186e-01  93.420
##  [25,] 16384  5.267e-01  89.180
##  [26,] 16384  5.349e-01  85.120
##  [27,] 16384  5.430e-01  81.260
##  [28,] 16384  5.512e-01  77.560
##  [29,] 16384  5.593e-01  74.040
##  [30,] 16384  5.675e-01  70.670
##  [31,] 16384  5.756e-01  67.460
##  [32,] 16384  5.837e-01  64.390
##  [33,] 16384  5.917e-01  61.470
##  [34,] 16384  5.998e-01  58.670
##  [35,] 16384  6.078e-01  56.010
##  [36,] 16384  6.157e-01  53.460
##  [37,] 16384  6.236e-01  51.030
##  [38,] 16384  6.315e-01  48.710
##  [39,] 16384  6.393e-01  46.500
##  [40,] 16384  6.470e-01  44.380
##  [41,] 16384  6.547e-01  42.370
##  [42,] 16384  6.623e-01  40.440
##  [43,] 16384  6.698e-01  38.600
##  [44,] 16384  6.773e-01  36.850
##  [45,] 16384  6.847e-01  35.170
##  [46,] 16384  6.920e-01  33.570
##  [47,] 16384  6.992e-01  32.050
##  [48,] 16384  7.063e-01  30.590
##  [49,] 16384  7.134e-01  29.200
##  [50,] 16384  7.203e-01  27.870
##  [51,] 16384  7.272e-01  26.610
##  [52,] 16384  7.339e-01  25.400
##  [53,] 16384  7.406e-01  24.240
##  [54,] 16384  7.471e-01  23.140
##  [55,] 16384  7.536e-01  22.090
##  [56,] 16384  7.599e-01  21.090
##  [57,] 16384  7.662e-01  20.130
##  [58,] 16384  7.723e-01  19.210
##  [59,] 16384  7.783e-01  18.340
##  [60,] 16384  7.842e-01  17.510
##  [61,] 16384  7.900e-01  16.710
##  [62,] 16384  7.957e-01  15.950
##  [63,] 16384  8.013e-01  15.230
##  [64,] 16384  8.068e-01  14.530
##  [65,] 16384  8.122e-01  13.870
##  [66,] 16384  8.174e-01  13.240
##  [67,] 16384  8.226e-01  12.640
##  [68,] 16384  8.276e-01  12.070
##  [69,] 16384  8.325e-01  11.520
##  [70,] 16384  8.374e-01  10.990
##  [71,] 16384  8.421e-01  10.490
##  [72,] 16384  8.467e-01  10.020
##  [73,] 16384  8.512e-01   9.562
##  [74,] 16384  8.556e-01   9.128
##  [75,] 16384  8.599e-01   8.713
##  [76,] 16384  8.641e-01   8.317
##  [77,] 16384  8.682e-01   7.939
##  [78,] 16384  8.721e-01   7.578
##  [79,] 16384  8.760e-01   7.233
##  [80,] 16384  8.798e-01   6.905
##  [81,] 16384  8.835e-01   6.591
##  [82,] 16384  8.871e-01   6.291
##  [83,] 16384  8.906e-01   6.005
##  [84,] 16384  8.940e-01   5.732
##  [85,] 16384  8.973e-01   5.472
##  [86,] 16384  9.006e-01   5.223
##  [87,] 16384  9.037e-01   4.986
##  [88,] 16384  9.068e-01   4.759
##  [89,] 16384  9.097e-01   4.543
##  [90,] 16384  9.126e-01   4.336
##  [91,] 16384  9.154e-01   4.139
##  [92,] 16384  9.182e-01   3.951
##  [93,] 16384  9.208e-01   3.772
##  [94,] 16384  9.234e-01   3.600
##  [95,] 16384  9.259e-01   3.436
##  [96,] 16384  9.283e-01   3.280
##  [97,] 16384  9.307e-01   3.131
##  [98,] 16384  9.330e-01   2.989
##  [99,] 16384  9.352e-01   2.853
## [100,] 16384  9.373e-01   2.723
## 
## $lambda.min
## [1] 2.988884
## 
## $lambda.1se
## [1] 5.471864
## 
## attr(,"class")
## [1] "cv.glmnet"

Now, let’s compute the predictions and see what kind of classification rate we get.

Predictions (log-odds):

# Predicted log-odds that each test image is "DanielRadcliffe", at the
# cross-validated penalty. Note that `s` is a value of lambda, not a
# position in the lambda path.
predict(fit, as.matrix(test_set[, 1:(im_size * im_size)]),
        s = fit_cv$lambda.min)
##                1
##  [1,]  0.6525192
##  [2,]  2.7220273
##  [3,]  0.3652630
##  [4,] -0.0223787
##  [5,] -1.1415293
##  [6,] -0.4084607
##  [7,] -0.2796465
##  [8,]  1.0560135
##  [9,]  1.1340886
## [10,]  1.3998971
## [11,] -0.2920484
## [12,]  1.8629664
## [13,]  1.1909852
## [14,]  1.3258180
## [15,]  1.4363487
## [16,]  1.0174629
## [17,]  0.4170194
## [18,]  1.3211269
## [19,]  0.2670056
## [20,]  2.5599361
## [21,]  1.2271168
## [22,]  0.5133525
## [23,]  1.6319309
## [24,]  0.6877331
## [25,]  0.8206764
## [26,]  0.4767357
## [27,]  0.7634759
## [28,]  1.3409136
## [29,] -1.2937397
## [30,] -1.2796628
## [31,] -1.1894561
## [32,] -0.1783895
## [33,] -0.8505973
## [34,] -1.4289506
## [35,] -1.0784371
## [36,] -0.4910984
## [37,] -1.2794197
## [38,] -0.8383343
## [39,] -0.5919297
## [40,] -1.3839425
## [41,]  1.4302030
## [42,] -0.8462872
## [43,] -0.4992556
## [44,] -1.7839323
## [45,] -0.5404087
## [46,] -1.8321558
## [47,] -0.6495346
## [48,] -0.7804304
## [49,] -1.1959140
## [50,] -1.3273931
## [51,] -0.8389319
## [52,] -1.0765960
## [53,] -0.3070338
## [54,] -1.7630432
## [55,] -1.2308185
## [56,] -2.2953734

Correct classification rate:

# Share of images whose predicted class (log-odds > 0 means
# "DanielRadcliffe") matches the true label, on the training set and then
# the test set, at the cross-validated lambda. `s` is a lambda value, not a
# path index.
mean((predict(fit, as.matrix(train_set[, 1:(im_size * im_size)]),
              s = fit_cv$lambda.min) > 0) ==
       (train_set$name == "DanielRadcliffe"))
## [1] 0.94375
mean((predict(fit, as.matrix(test_set[, 1:(im_size * im_size)]),
              s = fit_cv$lambda.min) > 0) ==
       (test_set$name == "DanielRadcliffe"))
## [1] 0.8928571

Let’s see what our model coefficients look like when displayed as an image:

# Draw the pixel weights at the cross-validated lambda as an image: drop the
# intercept, rescale to [0, 1] (guarding against an all-equal vector), and
# reshape column by column into im_size x im_size.
w <- as.matrix(coef(fit, s = fit_cv$lambda.min))[-1, 1]
w <- (w - min(w)) / max(diff(range(w)), .Machine$double.eps)
dim(w) <- c(im_size, im_size)
grid.raster(w)

Let’s try the same, but setting \(\lambda = 0\).

# Unpenalised fit (lambda = 0), to illustrate overfitting. The glmnet
# documentation warns against supplying a single lambda value; it is done
# here only for this demonstration.
fit0 <- glmnet(as.matrix(train_set[, 1:(im_size * im_size)]),
               y = factor(train_set[, (im_size * im_size) + 1]),
               alpha = 0, lambda = 0, family = "binomial")

# fit0 holds a single solution (lambda = 0), so predict at s = 0.
mean((predict(fit0, as.matrix(train_set[, 1:(im_size * im_size)]),
              s = 0) > 0) ==
       (train_set$name == "DanielRadcliffe"))
## [1] 1
mean((predict(fit0, as.matrix(test_set[, 1:(im_size * im_size)]),
              s = 0) > 0) ==
       (test_set$name == "DanielRadcliffe"))
## [1] 0.6607143

There is a lot of overfitting: training accuracy is 100%, but test accuracy drops to about 66%. Let’s visualize the coefficients:

# Weight image for the unpenalised model (its only column of coefficients):
# drop the intercept, min-max rescale, reshape and draw.
w <- as.matrix(coef(fit0))[-1, 1]
w <- (w - min(w)) / diff(range(w))
dim(w) <- c(im_size, im_size)
grid.raster(w)

What we see is that the model picked up random noise from the image – it’s hard to see a face at all.

Let’s make the training set smaller. With so few training examples, the model overfits even when we use a larger \(\lambda\).

# Small training set: DanielRadcliffe rows 1-20 plus AngieHarmon rows 90-110
# (rows 81-160 of train_set are AngieHarmon), i.e. 20 vs 21 images.
train_small <- rbind(train_set[1:20, ], train_set[90:110, ])
x_small <- as.matrix(train_small[, 1:(im_size * im_size)])
y_small <- factor(train_small[, (im_size * im_size) + 1])
fit_small <- glmnet(x_small, y = y_small, alpha = 0, family = "binomial")

# Cross-validation on the small set (random folds, hence the seed). Note that
# this overwrites the full-data fit_cv computed earlier.
set.seed(1)
fit_cv <- cv.glmnet(x_small, y = y_small, alpha = 0, family = "binomial")

# Accuracy of the small model at lambda = 90 (`s` is a lambda value), on its
# own training data and then on the held-out test set.
mean((predict(fit_small, x_small, s = 90) > 0) ==
       (train_small$name == "DanielRadcliffe"))
## [1] 1
mean((predict(fit_small, as.matrix(test_set[, 1:(im_size * im_size)]),
              s = 90) > 0) ==
       (test_set$name == "DanielRadcliffe"))
## [1] 0.6607143

Let’s visualize the coefficients from the small model.

# Weight image for the small model, using the first column of its
# coefficient path: drop the intercept, min-max rescale, reshape and draw.
w <- as.matrix(coef(fit_small))[-1, 1]
w <- (w - min(w)) / diff(range(w))
dim(w) <- c(im_size, im_size)
grid.raster(w)

It’s basically the average of a few pictures – you can even see the glasses.