GLM em R: Modelo Linear Generalizado com Exemplo

O que é regressão logística?

A regressão logística é usada para prever uma classe, ou seja, uma probabilidade. A regressão logística pode prever um resultado binário com precisão.

Imagine que você deseja prever se um empréstimo será negado / aceito com base em muitos atributos. A regressão logística é da forma 0/1. y = 0 se um empréstimo for rejeitado, y = 1 se aceito.

Um modelo de regressão logística difere do modelo de regressão linear de duas maneiras.

  • Em primeiro lugar, a regressão logística aceita apenas entrada dicotômica (binária) como uma variável dependente (ou seja, um vetor de 0 e 1).
  • Em segundo lugar, o resultado é medido pela seguinte função de ligação probabilística chamada sigmóide devido à sua forma de S:

A saída da função está sempre entre 0 e 1. Verifique a imagem abaixo

A função sigmóide retorna valores de 0 a 1. Para a tarefa de classificação, precisamos de uma saída discreta de 0 ou 1.

Para converter um fluxo contínuo em valor discreto, podemos definir um limite de decisão em 0,5. Todos os valores acima deste limite são classificados como 1

Neste tutorial, você aprenderá

Como criar um modelo de revestimento generalizado (GLM)

Vamos usar o adulto conjunto de dados para ilustrar a regressão logística. O 'adulto' é um ótimo conjunto de dados para a tarefa de classificação. O objetivo é prever se a renda anual em dólares de um indivíduo será superior a 50.000. O conjunto de dados contém 46.033 observações e dez recursos:

  • idade: idade do indivíduo. Numérico
  • educação: Nível educacional do indivíduo. Fator.
  • estado marital: estado civil do indivíduo. Fator ou seja, nunca casado, cônjuge casado, ...
  • gênero: gênero do indivíduo. Fator, ou seja, masculino ou feminino
  • renda: Variável alvo. Renda acima ou abaixo de 50K. Fator, ou seja,> 50K,<=50K

entre outros

 library(dplyr) data_adult <-read.csv('https://raw.githubusercontent.com/on2vhf-edu/R-Programming/master/adult.csv') glimpse(data_adult) 

Saída:

 Observations: 48,842 Variables: 10 $ x 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,... $ age 25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26... $ workclass Private, Private, Local-gov, Private, ?, Private,... $ education 11th, HS-grad, Assoc-acdm, Some-college, Some-col... $ educational.num 7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,... $ marital.status Never-married, Married-civ-spouse, Married-civ-sp... $ race Black, White, White, Black, White, White, Black, ... $ gender Male, Male, Male, Male, Female, Male, Male, Male,... $ hours.per.week 40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39... $ income <=50K, 50K,>50K, <=50K, <=50K, 5... 

Vamos proceder da seguinte forma:

  • Etapa 1: verificar as variáveis ​​contínuas
  • Etapa 2: verificar as variáveis ​​do fator
  • Etapa 3: Engenharia de recursos
  • Etapa 4: estatística de resumo
  • Etapa 5: conjunto de treinamento / teste
  • Etapa 6: construir o modelo
  • Etapa 7: avalie o desempenho do modelo
  • etapa 8: melhorar o modelo

Sua tarefa é prever qual indivíduo terá uma receita superior a 50 mil.

Neste tutorial, cada etapa será detalhada para realizar uma análise em um conjunto de dados real.

Etapa 1) Verifique as variáveis ​​contínuas

Na primeira etapa, você pode ver a distribuição das variáveis ​​contínuas.

continuous <-select_if(data_adult, is.numeric) summary(continuous)

Explicação do código

  • contínuo<- select_if(data_adult, is.numeric): Use the function select_if() from the dplyr library to select only the numerical columns
  • resumo (contínuo): Imprime a estatística de resumo

Saída:

## X age educational.num hours.per.week ## Min. : 1 Min. :17.00 Min. : 1.00 Min. : 1.00 ## 1st Qu.:11509 1st Qu.:28.00 1st Qu.: 9.00 1st Qu.:40.00 ## Median :23017 Median :37.00 Median :10.00 Median :40.00 ## Mean :23017 Mean :38.56 Mean :10.13 Mean :40.95 ## 3rd Qu.:34525 3rd Qu.:47.00 3rd Qu.:13.00 3rd Qu.:45.00 ## Max. :46033 Max. :90.00 Max. :16.00 Max. :99.00 

Na tabela acima, você pode ver que os dados têm escalas totalmente diferentes e hours.per.weeks tem grandes outliers (ou seja, observe o último quartil e o valor máximo).

Você pode lidar com isso seguindo duas etapas:

  • 1: Trace a distribuição de horas.por.semana
  • 2: Padronizar as variáveis ​​contínuas
  1. Trace a distribuição

Vamos examinar mais de perto a distribuição de hours.per.week

 # Histogram with kernel density curve library(ggplot2) ggplot(continuous, aes(x = hours.per.week)) + geom_density(alpha = .2, fill = '#FF6666') 

Saída:

A variável tem muitos outliers e uma distribuição não bem definida. Você pode resolver parcialmente esse problema excluindo os primeiros 0,01% das horas semanais.

Sintaxe básica do quantil:

quantile(variable, percentile) arguments: -variable: Select the variable in the data frame to compute the percentile -percentile: Can be a single value between 0 and 1 or multiple value. If multiple, use this format: `c(A,B,C, ...) - `A`,`B`,`C` and `...` are all integer from 0 to 1.

Calculamos o percentil 2 por cento superior

 top_one_percent <- quantile(data_adult$hours.per.week, .99) top_one_percent 

Explicação do código

  • quantil (data_adult $ hours.per.week, .99): Calcule o valor de 99 por cento do tempo de trabalho

Saída:

## 99% ## 80 

98 por cento da população trabalha menos de 80 horas por semana.

Você pode diminuir as observações acima deste limite. Você usa o filtro da biblioteca dplyr.

 data_adult_drop % filter(hours.per.week

Saída:

## [1] 45537 10 
  1. Padronizar as variáveis ​​contínuas

Você pode padronizar cada coluna para melhorar o desempenho porque seus dados não têm a mesma escala. Você pode usar a função mutate_if da biblioteca dplyr. A sintaxe básica é:

mutate_if(df, condition, funs(function)) arguments: -`df`: Data frame used to compute the function - `condition`: Statement used. Do not use parenthesis - funs(function): Return the function to apply. Do not use parenthesis for the function

Você pode padronizar as colunas numéricas da seguinte maneira:

 data_adult_rescale % mutate_if(is.numeric, funs(as.numeric(scale(.)))) head(data_adult_rescale)

Explicação do código

  • mutate_if (is.numeric, funs (scale)): A condição é apenas coluna numérica e a função é escala

Saída:

 ## X age workclass education educational.num ## 1 -1.732680 -1.02325949 Private 11th -1.22106443 ## 2 -1.732605 -0.03969284 Private HS-grad -0.43998868 ## 3 -1.732530 -0.79628257 Local-gov Assoc-acdm 0.73162494 ## 4 -1.732455 0.41426100 Private Some-college -0.04945081 ## 5 -1.732379 -0.34232873 Private 10th -1.61160231 ## 6 -1.732304 1.85178149 Self-emp-not-inc Prof-school 1.90323857 ## marital.status race gender hours.per.week income ## 1 Never-married Black Male -0.03995944 <=50K ## 2 Married-civ-spouse White Male 0.86863037 50K ## 4 Married-civ-spouse Black Male -0.03995944>50K ## 5 Never-married White Male -0.94854924 50K 

Etapa 2) Verifique as variáveis ​​do fator

Esta etapa tem dois objetivos:

  • Verifique o nível em cada coluna categórica
  • Defina novos níveis

Vamos dividir esta etapa em três partes:

  • Selecione as colunas categóricas
  • Armazene o gráfico de barras de cada coluna em uma lista
  • Imprima os gráficos

Podemos selecionar as colunas do fator com o código abaixo:

# Select categorical column factor <- data.frame(select_if(data_adult_rescale, is.factor)) ncol(factor)

Explicação do código

  • data.frame (select_if (data_adult, is.factor)): Armazenamos as colunas de fator em fator em um tipo de frame de dados. A biblioteca ggplot2 requer um objeto de quadro de dados.

Saída:

## [1] 6 

O conjunto de dados contém 6 variáveis ​​categóricas

A segunda etapa é mais habilidosa. Você deseja traçar um gráfico de barras para cada coluna no fator de quadro de dados. É mais conveniente automatizar o processo, especialmente quando há muitas colunas.

 library(ggplot2) # Create graph for each column graph <- lapply(names(factor), function(x) ggplot(factor, aes(get(x))) + geom_bar() + theme(axis.text.x = element_text(angle = 90)))

Explicação do código

  • lapply (): Use a função lapply () para passar uma função em todas as colunas do conjunto de dados. Você armazena a saída em uma lista
  • função (x): A função será processada para cada x. Aqui x são as colunas
  • ggplot (factor, aes (get (x))) + geom_bar () + theme (axis.text.x = element_text (angle = 90)): Crie um gráfico de barras para cada elemento x. Observe, para retornar x como uma coluna, você precisa incluí-lo dentro de get ()

A última etapa é relativamente fácil. Você deseja imprimir os 6 gráficos.

 # Print the graph graph

Saída:

## [[1]]

## ## [[2]]

## ## [[3]]

## ## [[4]]

## ## [[5]]

## ## [[6]]

Nota: Use o próximo botão para navegar para o próximo gráfico

Etapa 3) Engenharia de recursos

Reforma da educação

No gráfico acima, você pode ver que a variável educação possui 16 níveis. Isso é substancial e alguns níveis têm um número relativamente baixo de observações. Se você quiser melhorar a quantidade de informações que pode obter dessa variável, pode reformulá-la para um nível superior. Ou seja, você cria grupos maiores com nível de educação semelhante. Por exemplo, baixo nível de educação será convertido em evasão. Os níveis mais elevados de educação serão alterados para mestre.

Aqui está o detalhe:

Nível antigo

Novo nível

Pré escola

cair fora

10º

Cair fora

11º

Cair fora

12º

Cair fora

1o ao 4o

Cair fora

5º a 6º

Cair fora

7º a 8º

Cair fora

Cair fora

HS-Grad

HighGrad

Alguma faculdade

Comunidade

Assoc-acdm

Comunidade

Assoc-voc

Comunidade

Solteiros

Solteiros

Mestres

Mestres

Prof-escola

Mestres

Doutorado

PhD

recast_data % select(-X) % > % mutate(education = factor(ifelse(education == 'Preschool' | education == '10th' | education == '11th' | education == '12th' | education == '1st-4th' | education == '5th-6th' | education == '7th-8th' | education == '9th', 'dropout', ifelse(education == 'HS-grad', 'HighGrad', ifelse(education == 'Some-college' | education == 'Assoc-acdm' | education == 'Assoc-voc', 'Community', ifelse(education == 'Bachelors', 'Bachelors', ifelse(education == 'Masters' | education == 'Prof-school', 'Master', 'PhD')))))))

Explicação do código

  • Usamos o verbo mutate da biblioteca dplyr. Mudamos os valores da educação com a declaração ifelse

Na tabela abaixo, você cria uma estatística resumida para ver, em média, quantos anos de escolaridade (valor z) leva para se atingir o Bacharelado, Mestrado ou Doutorado.

 recast_data % > % group_by(education) % > % summarize(average_educ_year = mean(educational.num), count = n()) % > % arrange(average_educ_year)

Saída:

 ## # A tibble: 6 x 3 ## education average_educ_year count ## ## 1 dropout -1.76147258 5712 ## 2 HighGrad -0.43998868 14803 ## 3 Community 0.09561361 13407 ## 4 Bachelors 1.12216282 7720 ## 5 Master 1.60337381 3338 ## 6 PhD 2.29377644 557 

Reformulação do estado civil

Também é possível criar níveis mais baixos para o estado civil. No código a seguir, você altera o nível da seguinte maneira:

Nível antigo

Novo nível

Nunca casado

Solteiro

Casado-cônjuge-ausente

Solteiro

Cônjuge casado com AF

Casado

Cônjuge casada


Separados

Separados

Divorciado


Viúvas

Viúva

 # Change level marry recast_data % mutate(marital.status = factor(ifelse(marital.status == 'Never-married' | marital.status == 'Married-spouse-absent', 'Not_married', ifelse(marital.status == 'Married-AF-spouse' | marital.status == 'Married-civ-spouse', 'Married', ifelse(marital.status == 'Separated' | marital.status == 'Divorced', 'Separated', 'Widow')))))
Você pode verificar o número de indivíduos em cada grupo.
table(recast_data$marital.status)

Saída:

 ## ## Married Not_married Separated Widow ## 21165 15359 7727 1286 

Etapa 4) Estatística resumida

É hora de verificar algumas estatísticas sobre nossas variáveis ​​de destino. No gráfico abaixo, você conta a porcentagem de indivíduos que ganham mais de 50 mil de acordo com seu gênero.

# Plot gender income ggplot(recast_data, aes(x = gender, fill = income)) + geom_bar(position = 'fill') + theme_classic()

Saída:

A seguir, verifique se a origem do indivíduo afeta seus ganhos.

# Plot origin income ggplot(recast_data, aes(x = race, fill = income)) + geom_bar(position = 'fill') + theme_classic() + theme(axis.text.x = element_text(angle = 90))

Saída:

O número de horas de trabalho por gênero.

 # box plot gender working time ggplot(recast_data, aes(x = gender, y = hours.per.week)) + geom_boxplot() + stat_summary(fun.y = mean, geom = 'point', size = 3, color = 'steelblue') + theme_classic()

Saída:

O gráfico de caixa confirma que a distribuição do tempo de trabalho se ajusta a grupos diferentes. No box plot, ambos os gêneros não apresentam observações homogêneas.

Você pode verificar a densidade do tempo de trabalho semanal por tipo de ensino. As distribuições têm muitas opções distintas. Provavelmente, isso pode ser explicado pelo tipo de contrato nos EUA.

# Plot distribution working time by education ggplot(recast_data, aes(x = hours.per.week)) + geom_density(aes(color = education), alpha = 0.5) + theme_classic()

Explicação do código

  • ggplot (recast_data, aes (x = hours.per.week)): um gráfico de densidade requer apenas uma variável
  • geom_density (aes (color = education), alpha = 0.5): O objeto geométrico para controlar a densidade

Saída:

Para confirmar seus pensamentos, você pode realizar um teste ANOVA unilateral:

 anova <- aov(hours.per.week~education, recast_data) summary(anova)

Saída:

 ## Df Sum Sq Mean Sq F value Pr(>F) ## education 5 1552 310.31 321.2 <2e-16 *** ## Residuals 45531 43984 0.97 ## --- ## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1 

O teste ANOVA confirma a diferença de média entre os grupos.

Não-linearidade

Antes de executar o modelo, você pode ver se o número de horas trabalhadas está relacionado à idade.

 library(ggplot2) ggplot(recast_data, aes(x = age, y = hours.per.week)) + geom_point(aes(color = income), size = 0.5) + stat_smooth(method = 'lm', formula = y~poly(x, 2), se = TRUE, aes(color = income)) + theme_classic() 

Explicação do código

  • ggplot (recast_data, aes (x = idade, y = horas.per.semana)): define a estética do gráfico
  • geom_point (aes (cor = renda), tamanho = 0,5): Construir o gráfico de pontos
  • stat_smooth (): Adicione a linha de tendência com os seguintes argumentos:
    • method = 'lm': Plote o valor ajustado se a regressão linear
    • formula = y ~ poly (x, 2): Ajustar uma regressão polinomial
    • se = TRUE: Adicione o erro padrão
    • aes (cor = renda): Divida o modelo por renda

Saída:

Em suma, você pode testar os termos de interação no modelo para captar o efeito da não linearidade entre o tempo de trabalho semanal e outros recursos. É importante detectar em que condições o tempo de trabalho difere.

Correlação

A próxima verificação é visualizar a correlação entre as variáveis. Você converte o tipo de nível de fator em numérico para que possa plotar um mapa de calor contendo o coeficiente de correlação calculado com o método de Spearman.

 library(GGally) # Convert data to numeric corr <- data.frame(lapply(recast_data, as.integer)) # Plot the graphggcorr(corr, method = c('pairwise', 'spearman'), nbreaks = 6, hjust = 0.8, label = TRUE, label_size = 3, color = 'grey50') 

Explicação do código

  • data.frame (lapply (recast_data, as.integer)): converter dados em numéricos
  • ggcorr () plota o mapa de calor com os seguintes argumentos:
    • método: Método para calcular a correlação
    • nbreaks = 6: Número de quebra
    • hjust = 0,8: posição de controle do nome da variável no gráfico
    • rótulo = TRUE: Adicionar rótulos no centro das janelas
    • label_size = 3: rótulos de tamanho
    • color = 'grey50'): Cor do rótulo

Saída:

Etapa 5) Conjunto de treinamento / teste

Qualquer tarefa de aprendizado de máquina supervisionada exige a divisão dos dados entre um conjunto de trens e um conjunto de teste. Você pode usar a 'função' criada nos outros tutoriais de aprendizagem supervisionada para criar um conjunto de treinamento / teste.

 set.seed(1234) create_train_test <- function(data, size = 0.8, train = TRUE) { n_row = nrow(data) total_row = size * n_row train_sample <- 1: total_row if (train == TRUE) { return (data[train_sample, ]) } else { return (data[-train_sample, ]) } } data_train <- create_train_test(recast_data, 0.8, train = TRUE) data_test <- create_train_test(recast_data, 0.8, train = FALSE) dim(data_train)

Saída:

## [1] 36429 9
dim(data_test)

Saída:

## [1] 9108 9 

Etapa 6) Construir o modelo

Para ver o desempenho do algoritmo, use o pacote glm (). o Modelo Linear Generalizado é uma coleção de modelos. A sintaxe básica é:

 glm(formula, data=data, family=linkfunction() Argument: - formula: Equation used to fit the model- data: dataset used - Family: - binomial: (link = 'logit') - gaussian: (link = 'identity') - Gamma: (link = 'inverse') - inverse.gaussian: (link = '1/mu^2') - poisson: (link = 'log') - quasi: (link = 'identity', variance = 'constant') - quasibinomial: (link = 'logit') - quasipoisson: (link = 'log') 

Você está pronto para estimar o modelo logístico para dividir o nível de renda entre um conjunto de recursos.

 formula <- income~. logit <- glm(formula, data = data_train, family = 'binomial') summary(logit) 

Explicação do código

  • Fórmula<- income ~ .: Create the model to fit
  • logit<- glm(formula, data = data_train, family = 'binomial'): Fit a logistic model (family = 'binomial') with the data_train data.
  • resumo (logit): Imprima o resumo do modelo

Saída:

 ## ## Call: ## glm(formula = formula, family = 'binomial', data = data_train) ## ## Deviance Residuals: ## Min 1Q Median 3Q Max ## -2.6456 -0.5858 -0.2609 -0.0651 3.1982 ## ## Coefficients: ## Estimate Std. Error z value Pr(>|z|) ## (Intercept) 0.07882 0.21726 0.363 0.71675 ## age 0.41119 0.01857 22.146 <2e-16 *** ## workclassLocal-gov -0.64018 0.09396 -6.813 9.54e-12 *** ## workclassPrivate -0.53542 0.07886 -6.789 1.13e-11 *** ## workclassSelf-emp-inc -0.07733 0.10350 -0.747 0.45499 ## workclassSelf-emp-not-inc -1.09052 0.09140 -11.931 < 2e-16 *** ## workclassState-gov -0.80562 0.10617 -7.588 3.25e-14 *** ## workclassWithout-pay -1.09765 0.86787 -1.265 0.20596 ## educationCommunity -0.44436 0.08267 -5.375 7.66e-08 *** ## educationHighGrad -0.67613 0.11827 -5.717 1.08e-08 *** ## educationMaster 0.35651 0.06780 5.258 1.46e-07 *** ## educationPhD 0.46995 0.15772 2.980 0.00289 ** ## educationdropout -1.04974 0.21280 -4.933 8.10e-07 *** ## educational.num 0.56908 0.07063 8.057 7.84e-16 *** ## marital.statusNot_married -2.50346 0.05113 -48.966 < 2e-16 *** ## marital.statusSeparated -2.16177 0.05425 -39.846 < 2e-16 *** ## marital.statusWidow -2.22707 0.12522 -17.785 < 2e-16 *** ## raceAsian-Pac-Islander 0.08359 0.20344 0.411 0.68117 ## raceBlack 0.07188 0.19330 0.372 0.71001 ## raceOther 0.01370 0.27695 0.049 0.96054 ## raceWhite 0.34830 0.18441 1.889 0.05894 . ## genderMale 0.08596 0.04289 2.004 0.04506 * ## hours.per.week 0.41942 0.01748 23.998 < 2e-16 *** ## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1 ## ## (Dispersion parameter for binomial family taken to be 1) ## ## Null deviance: 40601 on 36428 degrees of freedom ## Residual deviance: 27041 on 36406 degrees of freedom ## AIC: 27087 ## ## Number of Fisher Scoring iterations: 6 

O resumo do nosso modelo revela informações interessantes. O desempenho de uma regressão logística é avaliado com métricas chave específicas.

  • AIC (Akaike Information Criteria): Este é o equivalente a R2 em regressão logística. Ele mede o ajuste quando uma penalidade é aplicada ao número de parâmetros. Menor AIC os valores indicam que o modelo está mais perto da verdade.
  • Desvio nulo: ajusta o modelo apenas com a interceptação. O grau de liberdade é n-1. Podemos interpretá-lo como um valor Qui-quadrado (valor ajustado diferente do teste de hipótese do valor real).
  • Desvio residual: Modelo com todas as variáveis. Também é interpretado como um teste de hipótese do qui-quadrado.
  • Número de iterações do Fisher Scoring: Número de iterações antes da convergência.

A saída da função glm () é armazenada em uma lista. O código a seguir mostra todos os itens disponíveis na variável logit que construímos para avaliar a regressão logística.

# A lista é muito longa, imprima apenas os três primeiros elementos

lapply(logit, class)[1:3]

Saída:

 ## $coefficients ## [1] 'numeric' ## ## $residuals ## [1] 'numeric' ## ## $fitted.values ## [1] 'numeric' 

Cada valor pode ser extraído com o sinal $ seguido do nome das métricas. Por exemplo, você armazenou o modelo como logit. Para extrair os critérios AIC, você usa:

logit$aic

Saída:

## [1] 27086.65

Etapa 7) Avalie o desempenho do modelo

Matriz de confusão

o matriz de confusão é a melhor escolha para avaliar o desempenho da classificação em comparação com as diferentes métricas que você viu antes. A ideia geral é contar quantas vezes as instâncias True são classificadas como Falsas.

Para calcular a matriz de confusão, primeiro você precisa ter um conjunto de previsões para que possam ser comparadas aos alvos reais.

 predict <- predict(logit, data_test, type = 'response') # confusion matrix table_mat 0.5) table_mat 

Explicação do código

  • predizer (logit, data_test, type = 'resposta'): Calcula a predição no conjunto de teste. Defina type = 'response' para calcular a probabilidade de resposta.
  • tabela (data_test $ renda, previsão> 0,5): Calcula a matriz de confusão. predizer> 0,5 significa que retorna 1 se as probabilidades previstas estiverem acima de 0,5, caso contrário, 0.

Saída:

 ## ## FALSE TRUE ## 50K 1074 1229 

Cada linha em uma matriz de confusão representa um alvo real, enquanto cada coluna representa um alvo previsto. A primeira linha desta matriz considera a renda inferior a 50k (a classe Falsa): 6241 foram classificados corretamente como indivíduos com renda inferior a 50k ( Verdadeiro negativo ), enquanto o restante foi classificado incorretamente como acima de 50k ( Falso positivo ) A segunda linha considera a receita acima de 50k, a classe positiva foi 1229 ( Verdadeiro positivo ), enquanto o Verdadeiro negativo era 1074.

Você pode calcular o modelo precisão somando o verdadeiro positivo + verdadeiro negativo sobre a observação total

 accuracy_Test <- sum(diag(table_mat)) / sum(table_mat) accuracy_Test

Explicação do código

  • soma (diag (table_mat)): Soma da diagonal
  • sum (table_mat): Soma da matriz.

Saída:

## [1] 0.8277339 

O modelo parece ter um problema: ele superestima o número de falsos negativos. Isso é chamado de Paradoxo do teste de precisão . Declaramos que a precisão é a razão entre as previsões corretas e o número total de casos. Podemos ter uma precisão relativamente alta, mas um modelo inútil. Acontece quando existe uma classe dominante. Se você olhar novamente para a matriz de confusão, verá que a maioria dos casos são classificados como negativos verdadeiros. Imagine agora, o modelo classificou todas as classes como negativas (ou seja, abaixo de 50k). Você teria uma precisão de 75 por cento (6718/6718 + 2257). Seu modelo tem um desempenho melhor, mas se esforça para distinguir o verdadeiro positivo do verdadeiro negativo.

Nessa situação, é preferível ter uma métrica mais concisa. Podemos olhar para:

  • Precisão = TP / (TP + FP)
  • Rechamada = TP / (TP + FN)

Precisão vs recall

Precisão analisa a precisão da previsão positiva. Lembrar é a proporção de instâncias positivas que são detectadas corretamente pelo classificador;

Você pode construir duas funções para calcular essas duas métricas

  1. Precisão de construção
 precision <- function(matrix) { # True positive tp <- matrix[2, 2] # false positive fp <- matrix[1, 2] return (tp / (tp + fp)) } 

Explicação do código

  • mat [1,1]: Retorna a primeira célula da primeira coluna do quadro de dados, ou seja, o verdadeiro positivo
  • esteira [1,2]; Retorna a primeira célula da segunda coluna do quadro de dados, ou seja, o falso positivo
 recall <- function(matrix) { # true positive tp <- matrix[2, 2]# false positive fn <- matrix[2, 1] return (tp / (tp + fn)) } 

Explicação do código

  • mat [1,1]: Retorna a primeira célula da primeira coluna do quadro de dados, ou seja, o verdadeiro positivo
  • tapete [2,1]; Retorna a segunda célula da primeira coluna do quadro de dados, ou seja, o falso negativo

Você pode testar suas funções

 prec <- precision(table_mat) prec rec <- recall(table_mat) rec 

Saída:

 ## [1] 0.712877 ## [2] 0.5336518

Quando o modelo diz que é um indivíduo acima de 50k, ele está correto em apenas 54% dos casos e pode reivindicar indivíduos acima de 50k em 72% dos casos.

Você pode criar o pontuação com base na precisão e recall. o é uma média harmônica dessas duas métricas, o que significa que dá mais peso aos valores mais baixos.

 f1 <- 2 * ((prec * rec) / (prec + rec)) f1 

Saída:

## [1] 0.6103799 

Troca de precisão vs recall

É impossível ter alta precisão e alto recall.

Se aumentarmos a precisão, o indivíduo correto será melhor previsto, mas perderíamos muitos deles (menor recall). Em algumas situações, preferimos maior precisão do que recall. Existe uma relação côncava entre precisão e recall.

  • Imagine, você precisa prever se um paciente tem uma doença. Você quer ser o mais preciso possível.
  • Se você precisar detectar pessoas potencialmente fraudulentas na rua por meio do reconhecimento facial, seria melhor detectar muitas pessoas rotuladas como fraudulentas, mesmo que a precisão seja baixa. A polícia poderá libertar o indivíduo não fraudulento.

A curva ROC

o Características operacionais do receptor curva é outra ferramenta comum usada com classificação binária. É muito semelhante à curva de precisão / rechamada, mas em vez de representar a precisão versus rechamada, a curva ROC mostra a taxa de verdadeiro positivo (isto é, rechamada) contra a taxa de falso positivo. A taxa de falsos positivos é a proporção de instâncias negativas que são classificadas incorretamente como positivas. É igual a um menos a taxa negativa verdadeira. A verdadeira taxa negativa também é chamada especificidade . Portanto, a curva ROC traça sensibilidade (lembre-se) versus especificidade 1

Para plotar a curva ROC, precisamos instalar uma biblioteca chamada RORC. Podemos encontrar no conda biblioteca . Você pode digitar o código:

conda install -c r r-rocr - sim

Podemos plotar o ROC com as funções prediction () e performance ().

 library(ROCR) ROCRpred <- prediction(predict, data_test$income) ROCRperf <- performance(ROCRpred, 'tpr', 'fpr') plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7)) 

Explicação do código

  • predição (predição, teste de dados $ renda): a biblioteca ROCR precisa criar um objeto de predição para transformar os dados de entrada
  • performance (ROCRpred, 'tpr', 'fpr'): Retorne as duas combinações para produzir no gráfico. Aqui, tpr e fpr são construídos. Para plotar a precisão e chamar juntos, use 'prec', 'rec'.

Saída:

Etapa 8) Melhore o modelo

Você pode tentar adicionar não linearidade ao modelo com a interação entre

  • idade e horas.por.semana
  • gênero e horas.por.semana.

Você precisa usar o teste de pontuação para comparar os dois modelos

 formula_2 <- income~age: hours.per.week + gender: hours.per.week + . logit_2 <- glm(formula_2, data = data_train, family = 'binomial') predict_2 <- predict(logit_2, data_test, type = 'response') table_mat_2 0.5) precision_2 <- precision(table_mat_2) recall_2 <- recall(table_mat_2) f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2)) f1_2 

Saída:

## [1] 0.6109181 

A pontuação é ligeiramente superior à anterior. Você pode continuar trabalhando nos dados e tentar bater a pontuação.

Resumo

Podemos resumir a função para treinar uma regressão logística na tabela abaixo:

Pacote

Objetivo

função

argumento

-

Criar conjunto de dados de treinamento / teste

create_train_set ()

dados, tamanho, trem

glm

Treine um modelo linear generalizado

glm ()

fórmula, dados, família *

glm

Resuma o modelo

resumo()

modelo ajustado

base

Fazer previsão

prever()

modelo ajustado, conjunto de dados, tipo = 'resposta'

base

Crie uma matriz de confusão

tabela()

e prever ()

base

Criar pontuação de precisão

soma (diag (tabela ()) / soma (tabela ()

ROCR

Criar ROC: Etapa 1 Criar previsão

predição()

predizer (), e

ROCR

Criar ROC: Etapa 2 Criar desempenho

atuação()

predição (), 'tpr', 'fpr'

ROCR

Criar ROC: Etapa 3 Gráfico de plotagem

enredo()

atuação()

O outro GLM tipos de modelos são:

- binomial: (link = 'logit')

- gaussian: (link = 'identidade')

- Gama: (link = 'inverso')

- inverse.gaussian: (link = '1 / mu ^ 2')

- poisson: (link = 'log')

- quase: (link = 'identidade', variância = 'constante')

- quasibinomial: (link = 'logit')

- quasipoisson: (link = 'log')