Conditional Probability

Author

Prof. Calvin

Published

February 17, 2025

Abstract:

0. Quarto Type-setting

  • This document is rendered with Quarto and configured to embed images using the embed-resources option in the header.
  • If you wish to use a similar header, here is the format specification for this document:

1. Setup

sh <- suppressPackageStartupMessages
sh(library(tidyverse))
sh(library(caret))
sh(library(naivebayes)) # bae caught me naivin'
sh(library(tidytext))
wine <- readRDS(gzcon(url("https://github.com/cd-public/D505/raw/master/dat/pinot.rds")))

2. Conditional Probability

  • Calculate the probability that a Pinot comes from Burgundy given it has the word ‘fruit’ in the description.
    • Take \(A\) to be the event that a Pinot was grown in Burgundy.
    • Take \(B\) to be the event that a Pinot has the word ‘fruit’ in its description.

\[ P(A|B) = \frac{P(A \cap B)}{P(B)} \]

nrow(filter(wine,province=="Burgundy" & str_detect(description,"fruit")))/nrow(filter(wine, str_detect(description,"fruit")))
[1] 0.2196038
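
The count ratio above is the conditional probability in disguise: dividing both row counts by the total number of wines turns them into probabilities, and the totals cancel. The sketch below checks this on a small made-up data frame (the `toy` rows are hypothetical, not drawn from the wine data) and uses base R only, so it runs without the packages loaded in Setup.

```r
# Toy data: hypothetical rows, not taken from the pinot dataset.
toy <- data.frame(
  province    = c("Burgundy", "Burgundy", "Oregon", "California", "Oregon"),
  description = c("red fruit and earth", "firm tannins", "bright fruit",
                  "ripe fruit, oak", "cherry and spice"),
  stringsAsFactors = FALSE
)

in_burgundy <- toy$province == "Burgundy"                     # event A
has_fruit   <- grepl("fruit", toy$description, fixed = TRUE)  # event B

# P(A | B) = P(A and B) / P(B); the mean of a logical vector is a proportion.
p_a_and_b   <- mean(in_burgundy & has_fruit)
p_b         <- mean(has_fruit)
p_a_given_b <- p_a_and_b / p_b

# Equivalent shortcut: the share of Burgundy among rows where B holds.
p_shortcut <- mean(in_burgundy[has_fruit])
```

Both `p_a_given_b` and `p_shortcut` equal one third here, because one of the three toy descriptions containing ‘fruit’ comes from Burgundy.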

3. Naive Bayes Algorithm

  • We train a Naive Bayes classifier to predict a wine’s province using:
  1. An 80-20 train-test split.
  2. Three features engineered from the description.
  3. 5-fold cross-validation.
  • We report Kappa after using the model to predict provinces in the holdout sample.
wino <- wine %>%
  mutate(cherry = str_detect(description, "cherry"),
         chocolate = str_detect(description, "chocolate"),
         earth = str_detect(description, "earth")) %>%
  select(-description)

wine_index <- createDataPartition(wino$province, p = 0.80, list = FALSE)
train <- wino[ wine_index, ]
test <- wino[-wine_index, ]

fit <- train(province ~ .,
             data = train, 
             method = "naive_bayes",
             metric = "Kappa",
             trControl = trainControl(method = "cv", number = 5))

confusionMatrix(predict(fit, test),factor(test$province))
Confusion Matrix and Statistics

                   Reference
Prediction          Burgundy California Casablanca_Valley Marlborough New_York
  Burgundy               148         69                10          13        3
  California              75        699                 8          21       12
  Casablanca_Valley        3          4                 0           0        1
  Marlborough              0          1                 2           2        0
  New_York                 5          3                 4           5       10
  Oregon                   7         15                 2           4        0
                   Reference
Prediction          Oregon
  Burgundy              84
  California           315
  Casablanca_Valley      2
  Marlborough            8
  New_York               7
  Oregon               131

Overall Statistics
                                          
               Accuracy : 0.5918          
                 95% CI : (0.5678, 0.6154)
    No Information Rate : 0.4728          
    P-Value [Acc > NIR] : < 2.2e-16       
                                          
                  Kappa : 0.3428          
                                          
 Mcnemar's Test P-Value : < 2.2e-16       

Statistics by Class:

                     Class: Burgundy Class: California Class: Casablanca_Valley
Sensitivity                  0.62185            0.8837                 0.000000
Specificity                  0.87526            0.5113                 0.993928
Pos Pred Value               0.45260            0.6186                 0.000000
Neg Pred Value               0.93314            0.8306                 0.984366
Prevalence                   0.14226            0.4728                 0.015541
Detection Rate               0.08846            0.4178                 0.000000
Detection Prevalence         0.19546            0.6754                 0.005977
Balanced Accuracy            0.74856            0.6975                 0.496964
                     Class: Marlborough Class: New_York Class: Oregon
Sensitivity                    0.044444        0.384615       0.23949
Specificity                    0.993243        0.985428       0.97513
Pos Pred Value                 0.153846        0.294118       0.82390
Neg Pred Value                 0.974096        0.990238       0.72523
Prevalence                     0.026898        0.015541       0.32696
Detection Rate                 0.001195        0.005977       0.07830
Detection Prevalence           0.007770        0.020323       0.09504
Balanced Accuracy              0.518844        0.685022       0.60731
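
To make the reported Kappa less of a black box, the sketch below recomputes it by hand from the confusion matrix printed above, using the usual Cohen's kappa formula: observed agreement minus chance agreement, divided by one minus chance agreement. The variable names (`cm`, `observed`, `expected`, `cohen_kappa`) are illustrative choices, and the block uses base R only.

```r
# Confusion matrix copied from the output above
# (rows = predictions, columns = reference).
classes <- c("Burgundy", "California", "Casablanca_Valley",
             "Marlborough", "New_York", "Oregon")
cm <- matrix(c(148,  69, 10, 13,  3,  84,
                75, 699,  8, 21, 12, 315,
                 3,   4,  0,  0,  1,   2,
                 0,   1,  2,  2,  0,   8,
                 5,   3,  4,  5, 10,   7,
                 7,  15,  2,  4,  0, 131),
             nrow = 6, byrow = TRUE,
             dimnames = list(Prediction = classes, Reference = classes))

n        <- sum(cm)                                 # size of the holdout sample
observed <- sum(diag(cm)) / n                       # overall accuracy
expected <- sum(rowSums(cm) * colSums(cm)) / n^2    # agreement expected by chance
cohen_kappa <- (observed - expected) / (1 - expected)

round(c(accuracy = observed, kappa = cohen_kappa), 4)
```

Rounded to four decimals, this reproduces the accuracy (0.5918) and Kappa (0.3428) in the output above. Kappa sits well below accuracy because a classifier that mostly predicts California already agrees with the reference labels fairly often by chance.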

4. Frequency Differences

  • We find the three words that most distinguish New York Pinots from all other Pinots.

Calculate each word's count relative to the number of descriptions.

# Occurrences of each word per description, excluding stop words and `omits`.
wc <- function(df, omits) {
  n_docs <- nrow(df)  # number of descriptions
  df %>%
    unnest_tokens(word, description) %>% anti_join(stop_words) %>%
    filter(!(word %in% omits)) %>%
    count(word) %>%
    mutate(n = n / n_docs)
}
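
As a rough illustration of what this function measures, here is a base R sketch on three made-up descriptions. The `docs` vector and the one-word stop list are hypothetical stand-ins (tidytext's `stop_words` is far larger), and the tokenizer is a simplification of `unnest_tokens`.

```r
# Hypothetical descriptions, for illustration only.
docs <- c("cherry and earth", "cherry, cherry and more cherry", "earth and smoke")

# Lowercase, split on anything that is not a letter, and drop empty tokens.
tokens <- unlist(strsplit(tolower(docs), "[^a-z]+"))
tokens <- tokens[nzchar(tokens)]

# Drop a tiny stand-in stop-word list.
tokens <- tokens[!(tokens %in% c("and"))]

# Count occurrences of each word, then divide by the number of descriptions.
counts <- vapply(split(tokens, tokens), length, integer(1))
rel    <- counts / length(docs)

rel[["cherry"]]
rel[["earth"]]
```

Here ‘cherry’ occurs four times across three descriptions, so its relative count exceeds 1. The measure is therefore occurrences per description, not the share of descriptions that contain the word.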

Make corresponding dataframes.

omits <- c("pinot", "noir", "wine")
wc_ny <- wc(wine %>% filter(province=="New_York") %>% select(description), omits)
Joining with `by = join_by(word)`
wc_no <- wc(wine %>% filter(province!="New_York") %>% select(description), omits)
Joining with `by = join_by(word)`

Calculate difference in relative frequencies.

diff <- wc_ny %>%
    inner_join(wc_no, by = "word", suffix = c("_ny", "_no")) %>%
    mutate(diff = n_ny - n_no) %>%
    arrange(desc(abs(diff)))
    
diff %>% head(3)
# A tibble: 3 × 4
  word     n_ny  n_no  diff
  <chr>   <dbl> <dbl> <dbl>
1 cherry  0.916 0.431 0.485
2 tannins 0.580 0.234 0.346
3 palate  0.565 0.239 0.326
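
One design choice worth noting: `inner_join` keeps only words that occur in both groups, so a word used exclusively in New York descriptions never reaches the ranking. If such words should count, one option is a full join with missing frequencies treated as zero. The sketch below shows the idea in base R on made-up frequencies; the data frames `ny` and `other` and their values are hypothetical, and this is an alternative to the approach above rather than a correction of it.

```r
# Made-up relative frequencies, for illustration only.
ny    <- data.frame(word = c("cherry", "finger", "tannins"),
                    n = c(0.9, 0.3, 0.6), stringsAsFactors = FALSE)
other <- data.frame(word = c("cherry", "tannins", "oak"),
                    n = c(0.4, 0.2, 0.45), stringsAsFactors = FALSE)

# all = TRUE keeps words found in either group (a full join).
both <- merge(ny, other, by = "word", all = TRUE, suffixes = c("_ny", "_no"))

# A word missing from one group has frequency zero there.
both$n_ny[is.na(both$n_ny)] <- 0
both$n_no[is.na(both$n_no)] <- 0

# Rank by the absolute difference in relative frequency.
both$diff <- both$n_ny - both$n_no
both <- both[order(-abs(both$diff)), ]
both
```

In this toy example ‘finger’ and ‘oak’ each appear in only one group; an inner join would drop both, while the full join keeps them in the ranking.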

Plot it.

sh(library(plotly))

plot_ly(diff %>% top_n(25, diff), 
        x = ~n_ny, y = ~n_no, z = ~diff, text = ~word, 
        type = 'scatter3d', mode = 'markers+text', 
        marker = list(size = 5, color = ~diff, showscale = TRUE)) %>%
    layout(title = "3D Scatterplot of Word Frequencies",
           scene = list(
               xaxis = list(title = "Frequency in New York Pinots"),
               yaxis = list(title = "Frequency in Other Pinots"),
               zaxis = list(title = "Difference in Frequency")
           ))