The objective of this analysis is to develop a model to estimate diamond prices.
The diamonds data set from the ggplot2 package is used throughout the analysis. Here is an overview of the data:
# Show the structure of the data set: column types and leading values.
str(object = diamonds)
## tibble [53,940 x 10] (S3: tbl_df/tbl/data.frame)
## $ carat : num [1:53940] 0.23 0.21 0.23 0.29 0.31 0.24 0.24 0.26 0.22 0.23 ...
## $ cut : Ord.factor w/ 5 levels "Fair"<"Good"<..: 5 4 2 4 2 3 3 3 1 3 ...
## $ color : Ord.factor w/ 7 levels "D"<"E"<"F"<"G"<..: 2 2 2 6 7 7 6 5 2 5 ...
## $ clarity: Ord.factor w/ 8 levels "I1"<"SI2"<"SI1"<..: 2 3 5 4 2 6 7 3 4 5 ...
## $ depth : num [1:53940] 61.5 59.8 56.9 62.4 63.3 62.8 62.3 61.9 65.1 59.4 ...
## $ table : num [1:53940] 55 61 65 58 58 57 57 55 61 61 ...
## $ price : int [1:53940] 326 326 327 334 335 336 336 337 337 338 ...
## $ x : num [1:53940] 3.95 3.89 4.05 4.2 4.34 3.94 3.95 4.07 3.87 4 ...
## $ y : num [1:53940] 3.98 3.84 4.07 4.23 4.35 3.96 3.98 4.11 3.78 4.05 ...
## $ z : num [1:53940] 2.43 2.31 2.31 2.63 2.75 2.48 2.47 2.53 2.49 2.39 ...
# Per-column summary statistics of the unfiltered data.
summary(object = diamonds)
## carat cut color clarity depth
## Min. :0.2000 Fair : 1610 D: 6775 SI1 :13065 Min. :43.00
## 1st Qu.:0.4000 Good : 4906 E: 9797 VS2 :12258 1st Qu.:61.00
## Median :0.7000 Very Good:12082 F: 9542 SI2 : 9194 Median :61.80
## Mean :0.7979 Premium :13791 G:11292 VS1 : 8171 Mean :61.75
## 3rd Qu.:1.0400 Ideal :21551 H: 8304 VVS2 : 5066 3rd Qu.:62.50
## Max. :5.0100 I: 5422 VVS1 : 3655 Max. :79.00
## J: 2808 (Other): 2531
## table price x y
## Min. :43.00 Min. : 326 Min. : 0.000 Min. : 0.000
## 1st Qu.:56.00 1st Qu.: 950 1st Qu.: 4.710 1st Qu.: 4.720
## Median :57.00 Median : 2401 Median : 5.700 Median : 5.710
## Mean :57.46 Mean : 3933 Mean : 5.731 Mean : 5.735
## 3rd Qu.:59.00 3rd Qu.: 5324 3rd Qu.: 6.540 3rd Qu.: 6.540
## Max. :95.00 Max. :18823 Max. :10.740 Max. :58.900
##
## z
## Min. : 0.000
## 1st Qu.: 2.910
## Median : 3.530
## Mean : 3.539
## 3rd Qu.: 4.040
## Max. :31.800
##
# Keep only the rows whose three dimensions are all strictly positive,
# then summarise again to confirm the zero minima are gone.
diamonds <- diamonds %>%
  filter(x > 0, y > 0, z > 0)
summary(diamonds)
## carat cut color clarity depth
## Min. :0.2000 Fair : 1609 D: 6774 SI1 :13063 Min. :43.00
## 1st Qu.:0.4000 Good : 4902 E: 9797 VS2 :12254 1st Qu.:61.00
## Median :0.7000 Very Good:12081 F: 9538 SI2 : 9185 Median :61.80
## Mean :0.7977 Premium :13780 G:11284 VS1 : 8170 Mean :61.75
## 3rd Qu.:1.0400 Ideal :21548 H: 8298 VVS2 : 5066 3rd Qu.:62.50
## Max. :5.0100 I: 5421 VVS1 : 3654 Max. :79.00
## J: 2808 (Other): 2528
## table price x y
## Min. :43.00 Min. : 326 Min. : 3.730 Min. : 3.680
## 1st Qu.:56.00 1st Qu.: 949 1st Qu.: 4.710 1st Qu.: 4.720
## Median :57.00 Median : 2401 Median : 5.700 Median : 5.710
## Mean :57.46 Mean : 3931 Mean : 5.732 Mean : 5.735
## 3rd Qu.:59.00 3rd Qu.: 5323 3rd Qu.: 6.540 3rd Qu.: 6.540
## Max. :95.00 Max. :18823 Max. :10.740 Max. :58.900
##
## z
## Min. : 1.07
## 1st Qu.: 2.91
## Median : 3.53
## Mean : 3.54
## 3rd Qu.: 4.04
## Max. :31.80
##
# Report how many NA cells remain anywhere in the data set.
cat(
  "Total number of missing values in diamonds data set:",
  sum(is.na(diamonds))
)
## Total number of missing values in diamonds data set: 0
Let’s check whether any pattern can be observed by visualizing the independent variables. The first variable is “x”, the length of the diamond in mm. There appears to be an exponential relationship between price and x.
# Price against length (x) with a linear trend line overlaid.
ggplot(data = diamonds, mapping = aes(x = x, y = price)) +
  geom_point(alpha = 0.5) +
  geom_smooth(method = "lm") +
  labs(
    title = "Price vs X",
    x = "X",
    y = "Price"
  )
## `geom_smooth()` using formula 'y ~ x'
Let’s take the logarithm of price to check whether there is a linear relationship.
# log(price) against length (x) with a linear trend line overlaid.
ggplot(data = diamonds, mapping = aes(x = x, y = log(price))) +
  geom_point(alpha = 0.5) +
  geom_smooth(method = "lm") +
  labs(
    title = "log(Price) vs X",
    x = "X",
    y = "log(Price)"
  )
## `geom_smooth()` using formula 'y ~ x'
It’s the same for both “y” (width in mm) and “z” (depth in mm), too. So, let’s quickly take the logarithm and re-visualize the graphs.
# Price against width (y) with a linear trend line overlaid.
ggplot(data = diamonds, mapping = aes(x = y, y = price)) +
  geom_point(alpha = 0.5) +
  geom_smooth(method = "lm") +
  labs(
    title = "Price vs Y",
    x = "Y",
    y = "Price"
  )
## `geom_smooth()` using formula 'y ~ x'
# Same plot with the response on the log scale.
ggplot(data = diamonds, mapping = aes(x = y, y = log(price))) +
  geom_point(alpha = 0.5) +
  geom_smooth(method = "lm") +
  labs(
    title = "log(Price) vs Y",
    x = "Y",
    y = "log(Price)"
  )
## `geom_smooth()` using formula 'y ~ x'
# Price against depth (z) with a linear trend line overlaid.
ggplot(data = diamonds, mapping = aes(x = z, y = price)) +
  geom_point(alpha = 0.5) +
  geom_smooth(method = "lm") +
  labs(
    title = "Price vs Z",
    x = "Z",
    y = "Price"
  )
## `geom_smooth()` using formula 'y ~ x'
# Same plot with the response on the log scale.
ggplot(data = diamonds, mapping = aes(x = z, y = log(price))) +
  geom_point(alpha = 0.5) +
  geom_smooth(method = "lm") +
  labs(
    title = "log(Price) vs Z",
    x = "Z",
    y = "log(Price)"
  )
## `geom_smooth()` using formula 'y ~ x'
So, it is better to work with log(price) rather than price in the analysis. The same comparison for carat is shown below, including a log-log version.
# Price against carat with a linear trend line overlaid.
# The `+` after geom_smooth() is required: without it the plot expression
# ends at that line and labs() is evaluated and printed as a separate
# statement, so the axis labels and title never reach the plot.
ggplot(data = diamonds, mapping = aes(x = carat, y = price)) +
  geom_point(alpha = 0.5) +
  geom_smooth(method = "lm") +
  labs(
    x = "Carat",
    y = "Price",
    title = "Price vs Carat"
  )
## `geom_smooth()` using formula 'y ~ x'
# log(price) against carat with a linear trend line overlaid.
ggplot(data = diamonds, mapping = aes(x = carat, y = log(price))) +
  geom_point(alpha = 0.5) +
  geom_smooth(method = "lm") +
  labs(
    title = "log(Price) vs Carat",
    x = "Carat",
    y = "log(Price)"
  )
## `geom_smooth()` using formula 'y ~ x'
# Log-log version: both carat and price on the log scale.
ggplot(data = diamonds, mapping = aes(x = log(carat), y = log(price))) +
  geom_point(alpha = 0.5) +
  geom_smooth(method = "lm") +
  labs(
    title = "log(Price) vs log(Carat)",
    x = "log(Carat)",
    y = "log(Price)"
  )
## `geom_smooth()` using formula 'y ~ x'
# Strongly prefer fixed over scientific notation in printed numbers.
options(scipen = 999)
Let’s check the price distribution by color, cut and clarity:
# Mean price for each color grade.
color_diamonds <- summarise(
  group_by(diamonds, color),
  avg_col = mean(price)
)
## `summarise()` ungrouping output (override with `.groups` argument)
# Bar chart of the per-color mean price.
ggplot(data = color_diamonds, mapping = aes(x = color, y = avg_col)) +
  geom_col() +
  labs(
    title = "Color Distribution by Price",
    x = "Color",
    y = "Price"
  )
# Mean price for each cut grade.
cut_diamonds <- summarise(
  group_by(diamonds, cut),
  avg_cut = mean(price)
)
## `summarise()` ungrouping output (override with `.groups` argument)
# Bar chart of the per-cut mean price.
ggplot(data = cut_diamonds, mapping = aes(x = cut, y = avg_cut)) +
  geom_col() +
  labs(
    title = "Cut Distribution by Price",
    x = "Cut",
    y = "Price"
  )
# Mean price for each clarity grade.
clar_diamonds <- summarise(
  group_by(diamonds, clarity),
  avg_clar = mean(price)
)
## `summarise()` ungrouping output (override with `.groups` argument)
# Bar chart of the per-clarity mean price.
ggplot(data = clar_diamonds, mapping = aes(x = clarity, y = avg_clar)) +
  geom_col() +
  labs(
    title = "Clarity Distribution by Price",
    x = "Clarity",
    y = "Price"
  )
Here, you can find the combined version of the distributions above:
# Price by color grade, points colored by cut, one panel per clarity grade.
ggplot(
  data = diamonds,
  mapping = aes(x = color, y = price, color = cut)
) +
  geom_jitter(alpha = 0.5) +
  facet_wrap(~clarity, ncol = 2) +
  labs(
    title = "Price vs Color, Cut and Clarity",
    x = "Color",
    y = "Price",
    color = "Cut"
  ) +
  scale_color_brewer(palette = "Dark2")
From the plot below, we can also infer that x, y, z and price are positively correlated, as they all increase together.
# x against y, with z mapped to point color and price mapped to point size.
ggplot(
  data = diamonds,
  mapping = aes(x = x, y = y, color = z, size = price)
) +
  geom_point(alpha = 0.3)
For the sake of supervised learning, train and test data are split as shown below:
# Fix the RNG seed so the random split is reproducible.
set.seed(503)

# Test set: tag every row with its row number, then draw 20% of the rows
# within each cut/color/clarity combination.
diamonds_test <- diamonds %>%
  mutate(diamond_id = row_number()) %>%
  group_by(cut, color, clarity) %>%
  sample_frac(0.2) %>%
  ungroup()

# Training set: every tagged row whose id was not drawn into the test set.
diamonds_train <- diamonds %>%
  mutate(diamond_id = row_number()) %>%
  anti_join(diamonds_test, by = "diamond_id")
The CART model is constructed as follows:
# Fit a regression tree for price on all remaining columns.
# The helper id column is dropped by name rather than by position, so the
# `price ~ .` formula cannot pick it up as a predictor even if the column
# layout of diamonds_train changes.
diamonds_model <- rpart(
  price ~ .,
  data = dplyr::select(diamonds_train, -diamond_id)
)

# Draw the fitted tree.
fancyRpartPlot(diamonds_model, type = 5, digits = 3)

# With no newdata supplied, predict() returns the fitted values for the
# training rows (in-sample predictions).
diamonds_in_sample <- predict(diamonds_model)
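A linear model on the original values has an R-squared of 92%, whereas it is 98% when price and carat are replaced by log(price) and log(carat). Note that the two R-squared values are computed on different response scales (price vs. log(price)), so they are not directly comparable.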
# Linear model of price on carat, the three quality grades and the
# three dimensions, fitted on the training set.
lm_diamonds1 <- lm(
  formula = price ~ carat + cut + clarity + color + x + y + z,
  data = diamonds_train
)
summary(lm_diamonds1)
##
## Call:
## lm(formula = price ~ carat + cut + clarity + color + x + y +
## z, data = diamonds_train)
##
## Residuals:
## Min 1Q Median 3Q Max
## -21743.7 -589.9 -182.2 375.9 10800.0
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 531.94 93.43 5.693 0.0000000125483 ***
## carat 11437.31 56.54 202.279 < 0.0000000000000002 ***
## cut.L 739.12 22.45 32.925 < 0.0000000000000002 ***
## cut.Q -336.42 19.64 -17.127 < 0.0000000000000002 ***
## cut.C 183.64 16.97 10.821 < 0.0000000000000002 ***
## cut^4 -15.17 13.60 -1.116 0.2646
## clarity.L 4109.21 33.74 121.779 < 0.0000000000000002 ***
## clarity.Q -1939.16 31.55 -61.464 < 0.0000000000000002 ***
## clarity.C 982.37 26.98 36.413 < 0.0000000000000002 ***
## clarity^4 -372.00 21.52 -17.283 < 0.0000000000000002 ***
## clarity^5 244.93 17.56 13.950 < 0.0000000000000002 ***
## clarity^6 11.27 15.28 0.737 0.4608
## clarity^7 87.10 13.48 6.461 0.0000000001053 ***
## color.L -1966.45 19.35 -101.643 < 0.0000000000000002 ***
## color.Q -686.17 17.59 -39.009 < 0.0000000000000002 ***
## color.C -162.04 16.40 -9.878 < 0.0000000000000002 ***
## color^4 27.87 15.07 1.850 0.0643 .
## color^5 -101.17 14.23 -7.108 0.0000000000012 ***
## color^6 -58.13 12.94 -4.493 0.0000070546868 ***
## x -964.57 33.97 -28.398 < 0.0000000000000002 ***
## y 37.79 19.41 1.947 0.0516 .
## z -287.83 33.96 -8.476 < 0.0000000000000002 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 1126 on 43104 degrees of freedom
## Multiple R-squared: 0.9205, Adjusted R-squared: 0.9205
## F-statistic: 2.377e+04 on 21 and 43104 DF, p-value: < 0.00000000000000022
# Same predictors as the first model, but with log(price) as the response
# and log(carat) in place of carat. The I() wrappers are kept so the
# coefficient names in the printed summary stay the same.
lm_diamonds2 <- lm(
  formula = I(log(price)) ~ I(log(carat)) + cut + clarity + color + x + y + z,
  data = diamonds_train
)
summary(lm_diamonds2)
##
## Call:
## lm(formula = I(log(price)) ~ I(log(carat)) + cut + clarity +
## color + x + y + z, data = diamonds_train)
##
## Residuals:
## Min 1Q Median 3Q Max
## -1.07021 -0.08534 0.00005 0.08293 1.91805
##
## Coefficients:
## Estimate Std. Error t value Pr(>|t|)
## (Intercept) 7.8836485 0.0339040 232.529 < 0.0000000000000002 ***
## I(log(carat)) 1.7158744 0.0100178 171.282 < 0.0000000000000002 ***
## cut.L 0.1122528 0.0027061 41.482 < 0.0000000000000002 ***
## cut.Q -0.0313798 0.0023365 -13.430 < 0.0000000000000002 ***
## cut.C 0.0141317 0.0020153 7.012 0.00000000000237803 ***
## cut^4 0.0003575 0.0016148 0.221 0.824809
## clarity.L 0.9146410 0.0040094 228.124 < 0.0000000000000002 ***
## clarity.Q -0.2463606 0.0037368 -65.928 < 0.0000000000000002 ***
## clarity.C 0.1337691 0.0031973 41.838 < 0.0000000000000002 ***
## clarity^4 -0.0656548 0.0025535 -25.712 < 0.0000000000000002 ***
## clarity^5 0.0269347 0.0020826 12.933 < 0.0000000000000002 ***
## clarity^6 -0.0011425 0.0018127 -0.630 0.528510
## clarity^7 0.0321599 0.0015995 20.106 < 0.0000000000000002 ***
## color.L -0.4436068 0.0022821 -194.389 < 0.0000000000000002 ***
## color.Q -0.0986197 0.0020836 -47.331 < 0.0000000000000002 ***
## color.C -0.0152686 0.0019462 -7.845 0.00000000000000441 ***
## color^4 0.0120571 0.0017879 6.744 0.00000000001561634 ***
## color^5 -0.0031565 0.0016888 -1.869 0.061616 .
## color^6 0.0015729 0.0015351 1.025 0.305544
## x 0.0821100 0.0056181 14.615 < 0.0000000000000002 ***
## y -0.0023189 0.0023031 -1.007 0.314019
## z 0.0142430 0.0041079 3.467 0.000526 ***
## ---
## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
##
## Residual standard error: 0.1336 on 43104 degrees of freedom
## Multiple R-squared: 0.9827, Adjusted R-squared: 0.9827
## F-statistic: 1.166e+05 on 21 and 43104 DF, p-value: < 0.00000000000000022