ML Assignment Part 2: Diamonds Data

In this part, we try to predict the price of a diamond using the “diamonds” data set, which ships with the R package “ggplot2”.

Before starting, we can look at a summary of the data:

library(tidyverse)
## ── Attaching packages ────────────────────────────────────────────────────────── tidyverse 1.2.1 ──
## ✔ ggplot2 2.2.1     ✔ purrr   0.2.4
## ✔ tibble  1.4.2     ✔ dplyr   0.7.4
## ✔ tidyr   0.8.0     ✔ stringr 1.3.0
## ✔ readr   1.1.1     ✔ forcats 0.3.0
## ── Conflicts ───────────────────────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
summary(diamonds)
##      carat               cut        color        clarity     
##  Min.   :0.2000   Fair     : 1610   D: 6775   SI1    :13065  
##  1st Qu.:0.4000   Good     : 4906   E: 9797   VS2    :12258  
##  Median :0.7000   Very Good:12082   F: 9542   SI2    : 9194  
##  Mean   :0.7979   Premium  :13791   G:11292   VS1    : 8171  
##  3rd Qu.:1.0400   Ideal    :21551   H: 8304   VVS2   : 5066  
##  Max.   :5.0100                     I: 5422   VVS1   : 3655  
##                                     J: 2808   (Other): 2531  
##      depth           table           price             x         
##  Min.   :43.00   Min.   :43.00   Min.   :  326   Min.   : 0.000  
##  1st Qu.:61.00   1st Qu.:56.00   1st Qu.:  950   1st Qu.: 4.710  
##  Median :61.80   Median :57.00   Median : 2401   Median : 5.700  
##  Mean   :61.75   Mean   :57.46   Mean   : 3933   Mean   : 5.731  
##  3rd Qu.:62.50   3rd Qu.:59.00   3rd Qu.: 5324   3rd Qu.: 6.540  
##  Max.   :79.00   Max.   :95.00   Max.   :18823   Max.   :10.740  
##                                                                  
##        y                z         
##  Min.   : 0.000   Min.   : 0.000  
##  1st Qu.: 4.720   1st Qu.: 2.910  
##  Median : 5.710   Median : 3.530  
##  Mean   : 5.735   Mean   : 3.539  
##  3rd Qu.: 6.540   3rd Qu.: 4.040  
##  Max.   :58.900   Max.   :31.800  
## 

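The data set contains 53,940 diamonds and 10 variables: price (in US dollars), carat, cut, color, clarity, the dimensions x, y and z (in mm), depth (total depth percentage) and table (width of the top relative to the widest point).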

We begin by creating the training and test sets with the code provided in the assignment:

library(tidyverse)
set.seed(503)
# Test set: a 20% sample drawn within each cut/color/clarity group
diamonds_test <- diamonds %>% mutate(diamond_id = row_number()) %>%
  group_by(cut, color, clarity) %>% sample_frac(0.2) %>% ungroup()
# Training set: every diamond that was not sampled into the test set
diamonds_train <- anti_join(diamonds %>% mutate(diamond_id = row_number()),
                            diamonds_test, by = "diamond_id")
diamonds_train
## # A tibble: 43,143 x 11
##    carat cut       color clarity depth table price     x     y     z
##    <dbl> <ord>     <ord> <ord>   <dbl> <dbl> <int> <dbl> <dbl> <dbl>
##  1 0.230 Ideal     E     SI2      61.5   55.   326  3.95  3.98  2.43
##  2 0.210 Premium   E     SI1      59.8   61.   326  3.89  3.84  2.31
##  3 0.230 Good      E     VS1      56.9   65.   327  4.05  4.07  2.31
##  4 0.290 Premium   I     VS2      62.4   58.   334  4.20  4.23  2.63
##  5 0.240 Very Good J     VVS2     62.8   57.   336  3.94  3.96  2.48
##  6 0.240 Very Good I     VVS1     62.3   57.   336  3.95  3.98  2.47
##  7 0.260 Very Good H     SI1      61.9   55.   337  4.07  4.11  2.53
##  8 0.220 Fair      E     VS2      65.1   61.   337  3.87  3.78  2.49
##  9 0.230 Very Good H     VS1      59.4   61.   338  4.00  4.05  2.39
## 10 0.300 Good      J     SI1      64.0   55.   339  4.25  4.28  2.73
## # ... with 43,133 more rows, and 1 more variable: diamond_id <int>
diamonds_test
## # A tibble: 10,797 x 11
##    carat cut   color clarity depth table price     x     y     z
##    <dbl> <ord> <ord> <ord>   <dbl> <dbl> <int> <dbl> <dbl> <dbl>
##  1 3.40  Fair  D     I1       66.8   52. 15964  9.42  9.34  6.27
##  2 0.900 Fair  D     SI2      64.7   59.  3205  6.09  5.99  3.91
##  3 0.950 Fair  D     SI2      64.4   60.  3384  6.06  6.02  3.89
##  4 1.00  Fair  D     SI2      65.2   56.  3634  6.27  6.21  4.07
##  5 0.700 Fair  D     SI2      58.1   60.  2358  5.79  5.82  3.37
##  6 1.04  Fair  D     SI2      64.9   56.  4398  6.39  6.34  4.13
##  7 0.700 Fair  D     SI2      65.6   55.  2167  5.59  5.50  3.64
##  8 1.03  Fair  D     SI2      66.4   56.  3743  6.31  6.19  4.15
##  9 1.10  Fair  D     SI2      64.6   54.  4725  6.56  6.49  4.22
## 10 2.01  Fair  D     SI2      59.4   66. 15627  8.20  8.17  4.86
## # ... with 10,787 more rows, and 1 more variable: diamond_id <int>
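As a quick sanity check of the split (a sketch, not part of the assignment code), the snippet below re-runs the same split code and confirms that the training and test sets do not overlap, that together they contain every diamond, and that roughly 20% of each cut lands in the test set. The names `diamonds_id`, `n_overlap`, `n_total` and `test_share_by_cut` are our own.

```r
library(dplyr)
library(ggplot2)  # provides the diamonds data set

# Re-create the split with the same seed and code as the assignment
set.seed(503)
diamonds_id <- diamonds %>% mutate(diamond_id = row_number())
diamonds_test <- diamonds_id %>%
  group_by(cut, color, clarity) %>%
  sample_frac(0.2) %>%
  ungroup()
diamonds_train <- anti_join(diamonds_id, diamonds_test, by = "diamond_id")

# The two sets should be disjoint and together cover every diamond
n_overlap <- length(intersect(diamonds_train$diamond_id, diamonds_test$diamond_id))
n_total <- nrow(diamonds_train) + nrow(diamonds_test)

# Share of each cut that ended up in the test set (about 0.2 by construction)
test_share_by_cut <- diamonds_id %>%
  mutate(in_test = diamond_id %in% diamonds_test$diamond_id) %>%
  group_by(cut) %>%
  summarise(test_share = mean(in_test))
print(test_share_by_cut)
```

Sampling inside each cut/color/clarity group keeps the mix of quality grades in the test set close to that of the full data.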

Since we want to predict price, we grow a regression tree with price as the response. I only want to see the effect of carat, cut, color and clarity on the diamond price, so these are the only predictors:

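library(rpart)
library(rpart.plot)
# Keep price and the four predictors of interest
diamond_data <- diamonds %>% dplyr::select(carat, cut, color, clarity, price)
# Regression tree; method = "anova" is rpart's default for a numeric response
diamond_model <- rpart(price ~ ., data = diamond_data, method = "anova")
rpart.plot(diamond_model)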

In the root node, which contains all diamonds, the average price is 3,933 dollars. The first split is on carat: the majority of diamonds (65%) weigh less than 0.99 carat and have a lower average price of 1,633 dollars, while the heavier diamonds are much more expensive, with an average price of 8,142 dollars.
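The node values in the plot can also be read directly from the fitted tree. A minimal sketch, assuming the same tree as above: rpart keeps one row per node in `diamond_model$frame`, with the node number as row name, the number of observations in `n` and the mean price in `yval`. It also checks that the root mean is the size-weighted average of the two child means, which is why 65% of 1,633 plus 35% of 8,142 comes out close to 3,933 (the percentages in the plot are rounded). The names `nodes` and `weighted_children` are our own.

```r
library(ggplot2)  # diamonds data
library(rpart)

# Same tree as above: price explained by carat, cut, color and clarity
diamond_data <- as.data.frame(diamonds[, c("carat", "cut", "color", "clarity", "price")])
diamond_model <- rpart(price ~ ., data = diamond_data, method = "anova")

# One row per node; row names are node numbers.
# Node 1 is the root, nodes 2 and 3 are its two children.
nodes <- diamond_model$frame[c("1", "2", "3"), c("var", "n", "yval")]
nodes$share <- nodes$n / nodes$n[1]
print(nodes)

# For a regression tree, the root mean equals the
# size-weighted average of its children's means
weighted_children <- sum(nodes$n[2:3] * nodes$yval[2:3]) / nodes$n[1]
```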

We can look for a similar relationship between cut and price, this time with a plot:

ggplot(diamonds, aes(x = cut, y = price)) +
  geom_point(aes(color = cut), alpha = 0.5, size = 1, position = "jitter")

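The plot does not show a clear rise in price from Fair to Ideal: every cut spans almost the full price range. A likely reason is that cut is confounded with carat, since Ideal diamonds tend to be smaller than Fair ones. Once carat is held fixed, as in the regression models below, the linear cut term (cut: .L) is positive, so better cuts are associated with higher prices.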
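A numeric summary makes this comparison easier to read than the jittered plot. This sketch (the column names are our own) computes the number of diamonds, the mean and median price, and the mean carat for each cut.

```r
library(dplyr)
library(ggplot2)  # diamonds data

# Price and size statistics for each cut grade
by_cut <- diamonds %>%
  group_by(cut) %>%
  summarise(n = n(),
            mean_price = mean(price),
            median_price = median(price),
            mean_carat = mean(carat))
print(by_cut)
```

Comparing `mean_price` with `mean_carat` row by row shows whether price differences between cuts line up with differences in size rather than with cut quality.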

For price prediction, I found a different model online and would like to present it with some examples (reference: https://rstudio-pubs-static.s3.amazonaws.com/94067_d1fdfafd20b14725a2578647031760c2.html):

library(memisc)
## Loading required package: lattice
## Loading required package: MASS
## 
## Attaching package: 'MASS'
## The following object is masked from 'package:dplyr':
## 
##     select
## 
## Attaching package: 'memisc'
## The following objects are masked from 'package:dplyr':
## 
##     collect, recode, rename
## The following objects are masked from 'package:stats':
## 
##     contr.sum, contr.treatment, contrasts
## The following object is masked from 'package:base':
## 
##     as.array
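# Response on the log10 scale; the first predictor is the cube root of carat
m1 <- lm(I(log10(price)) ~ I(carat^(1/3)), data = diamonds)
# Add one predictor at a time
m2 <- update(m1, ~ . + carat)
m3 <- update(m2, ~ . + cut)
m4 <- update(m3, ~ . + color)
m5 <- update(m4, ~ . + clarity)
# Compare the five models side by side
mtable(m1, m2, m3, m4, m5)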
## 
## Calls:
## m1: lm(formula = I(log10(price)) ~ I(carat^(1/3)), data = diamonds)
## m2: lm(formula = I(log10(price)) ~ I(carat^(1/3)) + carat, data = diamonds)
## m3: lm(formula = I(log10(price)) ~ I(carat^(1/3)) + carat + cut, 
##     data = diamonds)
## m4: lm(formula = I(log10(price)) ~ I(carat^(1/3)) + carat + cut + 
##     color, data = diamonds)
## m5: lm(formula = I(log10(price)) ~ I(carat^(1/3)) + carat + cut + 
##     color + clarity, data = diamonds)
## 
## ==============================================================================================
##                        m1             m2             m3             m4              m5        
## ----------------------------------------------------------------------------------------------
##   (Intercept)          1.225***       0.451***       0.380***       0.405***        0.180***  
##                       (0.003)        (0.008)        (0.008)        (0.007)         (0.004)    
##   I(carat^(1/3))       2.414***       3.721***       3.780***       3.665***        3.971***  
##                       (0.003)        (0.014)        (0.013)        (0.012)         (0.007)    
##   carat                              -0.494***      -0.505***      -0.431***       -0.474***  
##                                      (0.005)        (0.005)        (0.004)         (0.003)    
##   cut: .L                                            0.097***       0.097***        0.052***  
##                                                     (0.002)        (0.002)         (0.001)    
##   cut: .Q                                           -0.027***      -0.027***       -0.013***  
##                                                     (0.002)        (0.001)         (0.001)    
##   cut: .C                                            0.022***       0.022***        0.006***  
##                                                     (0.001)        (0.001)         (0.001)    
##   cut: ^4                                            0.008***       0.008***       -0.001     
##                                                     (0.001)        (0.001)         (0.001)    
##   color: .L                                                        -0.162***       -0.191***  
##                                                                    (0.001)         (0.001)    
##   color: .Q                                                        -0.056***       -0.040***  
##                                                                    (0.001)         (0.001)    
##   color: .C                                                         0.001          -0.006***  
##                                                                    (0.001)         (0.001)    
##   color: ^4                                                         0.012***        0.005***  
##                                                                    (0.001)         (0.001)    
##   color: ^5                                                        -0.007***       -0.001*    
##                                                                    (0.001)         (0.001)    
##   color: ^6                                                        -0.010***        0.001     
##                                                                    (0.001)         (0.001)    
##   clarity: .L                                                                       0.394***  
##                                                                                    (0.001)    
##   clarity: .Q                                                                      -0.104***  
##                                                                                    (0.001)    
##   clarity: .C                                                                       0.057***  
##                                                                                    (0.001)    
##   clarity: ^4                                                                      -0.027***  
##                                                                                    (0.001)    
##   clarity: ^5                                                                       0.011***  
##                                                                                    (0.001)    
##   clarity: ^6                                                                      -0.001     
##                                                                                    (0.001)    
##   clarity: ^7                                                                       0.014***  
##                                                                                    (0.001)    
## ----------------------------------------------------------------------------------------------
##   R-squared            0.924          0.935          0.939          0.951           0.984     
##   adj. R-squared       0.924          0.935          0.939          0.951           0.984     
##   sigma                0.122          0.112          0.109          0.097           0.056     
##   F               652012.063     387489.366     138654.523      87959.467      173791.084     
##   p                    0.000          0.000          0.000          0.000           0.000     
##   Log-likelihood   37025.211      41356.392      43150.294      49222.951       79078.982     
##   Deviance           800.248        681.522        637.665        509.103         168.282     
##   AIC             -74044.422     -82704.783     -86284.589     -98417.901     -158115.964     
##   BIC             -74017.735     -82669.201     -86213.424     -98293.362     -157929.156     
##   N                53940          53940          53940          53940           53940         
## ==============================================================================================

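Using lm(), we build the linear model in five steps, adding one term at a time: the cube root of carat, carat itself, cut, color and clarity. The response is log10(price), and R-squared rises from 0.924 for m1 to 0.984 for m5.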
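The five-step construction can also be reproduced with base R alone, without memisc. This sketch refits the same models, extracts R-squared with `summary()`, and uses `anova()` for an F-test of the last step (adding clarity). Because each model is nested in the next, R-squared can never decrease from one step to the next. The names `models` and `r_squared` are our own.

```r
library(ggplot2)  # diamonds data

# Same five nested models as above
m1 <- lm(I(log10(price)) ~ I(carat^(1/3)), data = diamonds)
m2 <- update(m1, ~ . + carat)
m3 <- update(m2, ~ . + cut)
m4 <- update(m3, ~ . + color)
m5 <- update(m4, ~ . + clarity)

# R-squared of each step
models <- list(m1 = m1, m2 = m2, m3 = m3, m4 = m4, m5 = m5)
r_squared <- sapply(models, function(m) summary(m)$r.squared)
print(round(r_squared, 3))

# F-test: does adding clarity improve on m4?
print(anova(m4, m5))
```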

We test the model on two randomly chosen diamonds from the data set:

print(diamonds[1386,])
## # A tibble: 1 x 10
##   carat cut       color clarity depth table price     x     y     z
##   <dbl> <ord>     <ord> <ord>   <dbl> <dbl> <int> <dbl> <dbl> <dbl>
## 1 0.310 Very Good G     VS2      61.4   55.   559  4.38  4.41  2.69
print(diamonds[24596,])
## # A tibble: 1 x 10
##   carat cut   color clarity depth table price     x     y     z
##   <dbl> <ord> <ord> <ord>   <dbl> <dbl> <int> <dbl> <dbl> <dbl>
## 1  1.53 Ideal G     VS2      62.8   57. 12907  7.48  7.43  4.63
thisDiamond <- data.frame(carat=0.31, cut='Very Good',
                          color='G',clarity='VS2')
modelEstimate <- predict(m5,newdata = thisDiamond,
                         interval = "prediction",level = .95)
10^modelEstimate
##        fit     lwr      upr
## 1 578.7647 449.769 744.7569

Actual price: 559

Predicted price: 579

Difference: 3.5%

thisDiamond <- data.frame(carat=1.53, cut='Ideal',
                          color='G',clarity='VS2')
modelEstimate <- predict(m5,newdata = thisDiamond,
                         interval = "prediction",level = .95)
10^modelEstimate
##        fit      lwr     upr
## 1 12384.37 9624.169 15936.2

Actual price: 12907

Predicted price: 12384

Difference: 4%
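The percentage differences for the two example diamonds can also be computed in R instead of by hand. A minimal sketch, assuming the m5 formula above: it predicts both rows at once, back-transforms from the log10 scale with `10^`, and reports the signed percentage difference relative to the actual price (positive means the model overestimates). The names `examples`, `pred` and `comparison` are our own.

```r
library(ggplot2)  # diamonds data

# Refit the final model from above
m5 <- lm(I(log10(price)) ~ I(carat^(1/3)) + carat + cut + color + clarity,
         data = diamonds)

# The two example diamonds, taken directly from the data
examples <- diamonds[c(1386, 24596), ]

# Predictions are on the log10 scale, so back-transform with 10^
pred <- 10^predict(m5, newdata = examples,
                   interval = "prediction", level = 0.95)

comparison <- data.frame(
  actual    = examples$price,
  predicted = pred[, "fit"],
  lwr       = pred[, "lwr"],
  upr       = pred[, "upr"],
  pct_diff  = 100 * (pred[, "fit"] - examples$price) / examples$price
)
print(round(comparison, 1))
```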

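So, for these two diamonds the model predicts the price within about 3-4% of the actual value. This is only a rough indication, though: two diamonds are a small sample, and the 95% prediction intervals range from roughly 22% below to 29% above the predicted price.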
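Both example diamonds were also part of the data used to fit m5, so their errors are in-sample. The train/test split created at the beginning allows a fairer check. A sketch, assuming the same split code: refit m5 on `diamonds_train` only and compute the absolute percentage error on `diamonds_test`. Summarising these errors by their mean (MAPE) and median is our choice, not part of the assignment; the names `m5_train`, `test_pred` and `abs_pct_err` are our own.

```r
library(dplyr)
library(ggplot2)  # diamonds data

# Re-create the assignment's train/test split
set.seed(503)
diamonds_id <- diamonds %>% mutate(diamond_id = row_number())
diamonds_test <- diamonds_id %>%
  group_by(cut, color, clarity) %>%
  sample_frac(0.2) %>%
  ungroup()
diamonds_train <- anti_join(diamonds_id, diamonds_test, by = "diamond_id")

# Fit the final model on the training set only
m5_train <- lm(I(log10(price)) ~ I(carat^(1/3)) + carat + cut + color + clarity,
               data = diamonds_train)

# Predict the unseen test diamonds and back-transform to dollars
test_pred <- 10^predict(m5_train, newdata = diamonds_test)

# Absolute percentage error per test diamond, then its mean and median
abs_pct_err <- 100 * abs(test_pred - diamonds_test$price) / diamonds_test$price
mape <- mean(abs_pct_err)
median_ape <- median(abs_pct_err)
print(c(mape = mape, median_ape = median_ape))
```

Because the test diamonds were never seen during fitting, these numbers estimate how the model would do on new diamonds better than two in-sample examples can.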