【Machine Learning for Quantitative Outcomes】R Tutorial 03: Decision Tree Regression — Standardized Code

医学统计数据分析  2025-11-27
Overview: Machine Learning for Quantitative Outcomes, part 03 — a Decision Tree Regression tutorial in R with standardized code.





01

Concept, Principle, Idea, and Applications

Concept: Decision tree regression predicts through a tree structure in which each internal node tests a feature and each leaf node stores a predicted value.

Principle: The feature space is partitioned recursively so that the samples within each subspace are as homogeneous as possible. Mean squared error (MSE) is the usual splitting criterion for regression.

Idea: Partition the feature space into rectangular regions and predict with a constant in each region (formalized in the sketch below).

Applications: modeling nonlinear relationships, and settings where interpretability matters.
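
In standard CART notation (a sketch; the original text states this only in words), a split of a node on feature $j$ at threshold $s$ is chosen to minimize the summed squared error of the two child regions, and the fitted tree predicts with the leaf means:

$$\min_{j,\,s}\ \left[ \sum_{x_i \in R_L(j,s)} \big(y_i - \bar{y}_{R_L}\big)^2 + \sum_{x_i \in R_R(j,s)} \big(y_i - \bar{y}_{R_R}\big)^2 \right], \qquad \hat{f}(x) = \sum_{m=1}^{M} \bar{y}_{R_m}\,\mathbf{1}\{x \in R_m\}$$

where $R_L, R_R$ are the candidate left/right child regions and $R_1, \dots, R_M$ are the final leaf regions.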

02

Workflow

- Data preprocessing: read the Excel data, convert categorical variables to factors, check for missing values, and split into training and test sets.

- Model construction: specify a decision tree regression model (rpart, method = "anova") with 指标1 as the outcome.

- Training: fit the tree on the training set and tune the complexity parameter with 5-fold cross-validation.

- Evaluation: MSE, RMSE, MAE, and R-squared on both the training and test sets.

- Visualization: tree structure, variable importance, predicted vs actual values, residual diagnostics, learning curve, and stability plots.

- Saving results: CSV tables, JPG figures, a Word report, the R workspace, and the model objects. (A minimal sketch of the whole pipeline follows; the full code is in section 03.)
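
A minimal sketch, assuming the same example workbook 示例数据.xlsx, outcome 指标1, and ID column 序号 used in the full code (and that readxl, rpart, rpart.plot, and caret are installed):

library(readxl); library(rpart); library(caret)

data <- read_excel("示例数据.xlsx", sheet = "示例数据")                    # preprocessing
set.seed(123)
idx  <- as.vector(createDataPartition(data$指标1, p = 0.7, list = FALSE))  # 70/30 split
tree <- rpart(指标1 ~ . - 序号, data = data[idx, ], method = "anova")      # build + train
pred <- predict(tree, data[-idx, ])                                        # evaluate
rmse <- sqrt(mean((data$指标1[-idx] - pred)^2))
rpart.plot::rpart.plot(tree)                                               # visualize
saveRDS(tree, "tree_model.rds")                                            # save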


03

Code, Demonstration, and Functional Walkthrough

Machine learning models for predicting quantitative outcomes fall into four broad families: traditional statistical models, tree-based ensemble models, kernel methods, and deep learning models. Each family captures data patterns through a different mechanism, covering prediction tasks from linear to complex nonlinear relationships.

The code below automates the full process from data preparation to saving results, including data preprocessing, model configuration, performance evaluation, and report generation.

# Set the working directory and clear the environment
rm(list = ls())
if (!is.null(dev.list())) dev.off()
setwd("C:/Users/hyy/Desktop/")
if (!dir.exists("Results-DecisionTree-ML")) dir.create("Results-DecisionTree-ML")

# Load the required packages
if (!require(pacman)) install.packages("pacman")
pacman::p_load(
  readxl, dplyr, ggplot2, car, lmtest, ggpubr, corrplot, 
  performance, see, effects, sjPlot, report, officer, flextable,
  broom, gridExtra, patchwork, rpart, rpart.plot, party, 
  caret, randomForest, partykit, vip, pROC, MLmetrics, ModelMetrics
)

# Set the font (use the system default to avoid Chinese-font issues)
font_family <- "sans"

# Read the data
data <- read_excel("示例数据.xlsx", sheet = "示例数据")

# Data preprocessing
str(data)
summary(data)

# Convert categorical variables to factors
data$结局 <- as.factor(data$结局)
data$肥胖程度 <- as.factor(data$肥胖程度)
data$教育水平 <- as.factor(data$教育水平)
data$血型 <- as.factor(data$血型)
data$指标8 <- as.factor(data$指标8)

# Check for missing values
sum(is.na(data))

# Split the data for machine learning
set.seed(123)
train_index <- createDataPartition(data$指标1, p = 0.7, list = FALSE)
train_data <- data[train_index, ]
test_data <- data[-train_index, ]

cat("训练集样本数:", nrow(train_data), "\n")
cat("测试集样本数:", nrow(test_data), "\n")


# Predict on the training and test sets
y_pred_train <- predict(tree_model, train_data)
y_pred_test <- predict(tree_model, test_data)

# Training-set performance metrics
mse_train <- mean((train_data$指标1 - y_pred_train)^2)
rmse_train <- sqrt(mse_train)
mae_train <- mean(abs(train_data$指标1 - y_pred_train))
r_squared_train <- 1 - sum((train_data$指标1 - y_pred_train)^2) / 
  sum((train_data$指标1 - mean(train_data$指标1))^2)

# Test-set performance metrics
mse_test <- mean((test_data$指标1 - y_pred_test)^2)
rmse_test <- sqrt(mse_test)
mae_test <- mean(abs(test_data$指标1 - y_pred_test))
r_squared_test <- 1 - sum((test_data$指标1 - y_pred_test)^2) / 
  sum((test_data$指标1 - mean(test_data$指标1))^2)

# Cross-validated performance (on the training set)
set.seed(123)
train_control <- trainControl(method = "cv", number = 5)
cv_tree <- train(指标1 ~ . - 序号, 
                 data = train_data, 
                 method = "rpart",
                 trControl = train_control,
                 tuneLength = 5)
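
# Optional addition (not in the original post): inspect the CV-selected
# complexity parameter and prune the fitted tree to it with rpart::prune()
print(cv_tree$bestTune)
pruned_tree <- prune(tree_model, cp = cv_tree$bestTune$cp)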

# Collect model performance results
model_performance <- data.frame(
  Dataset = c("Training", "Training", "Training", "Training",
              "Testing", "Testing", "Testing", "Testing", "CV"),
  Metric = c("MSE", "RMSE", "MAE", "R-squared",
             "MSE", "RMSE", "MAE", "R-squared", "CV-RMSE"),
  Value = c(mse_train, rmse_train, mae_train, r_squared_train,
            mse_test, rmse_test, mae_test, r_squared_test,
            min(cv_tree$results$RMSE, na.rm = TRUE))
)

# Variable importance
if(!is.null(tree_model$variable.importance)) {
  var_importance <- data.frame(
    Variable = names(tree_model$variable.importance),
    Importance = as.numeric(tree_model$variable.importance)
  ) %>% arrange(desc(Importance))
} else {
  var_importance <- data.frame(
    Variable = "结局",
    Importance = 1
  )
}

# Save the results
write.csv(model_performance, "Results-DecisionTree-ML/决策树模型性能_ML.csv", row.names = FALSE, fileEncoding = "UTF-8")
write.csv(var_importance, "Results-DecisionTree-ML/决策树变量重要性.csv", row.names = FALSE, fileEncoding = "UTF-8")

# 1. Decision tree visualization
p_tree_basic <- function() {
  rpart.plot(tree_model, 
             type = 2, 
             extra = 101, 
             fallen.leaves = FALSE,
             main = "Decision Tree Structure (Trained on Training Set)",
             box.palette = "Blues")
}

p_tree_fancy <- function() {
  prp(tree_model,
      type = 2,
      extra = 101,
      nn = TRUE,
      fallen.leaves = TRUE,
      branch = 0.5,
      faclen = 0,
      varlen = 0,
      shadow.col = "gray",
      box.col = "lightblue")
}

# 2. Variable importance plot
p_var_importance <- ggplot(var_importance, 
                           aes(x = Importance, y = reorder(Variable, Importance))) +
  geom_col(fill = "steelblue", alpha = 0.8) +
  labs(title = "Decision Tree Variable Importance",
       x = "Importance Score"
       y = "Variables") +
  theme_minimal() +
  theme(plot.title = element_text(hjust = 0.5, size = 14, face = "bold"))

# 3. Predicted vs actual values on the training and test sets
# Training set
pred_actual_train <- data.frame(
  Dataset = "Training",
  Actual = train_data$指标1,
  Predicted = y_pred_train
)

# Test set
pred_actual_test <- data.frame(
  Dataset = "Testing",
  Actual = test_data$指标1,
  Predicted = y_pred_test
)

pred_actual_combined <- rbind(pred_actual_train, pred_actual_test)

p_pred_vs_actual <- ggplot(pred_actual_combined, aes(x = Actual, y = Predicted, color = Dataset)) +
  geom_point(alpha = 0.6) +
  geom_abline(intercept = 0, slope = 1, color = "red", linetype = "dashed") +
  labs(title = "Predicted vs Actual Values - Training vs Testing",
       x = "Actual Values"
       y = "Predicted Values") +
  theme_minimal() +
  scale_color_manual(values = c("Training" = "blue", "Testing" = "green")) +
  theme(legend.position = "bottom")

# 4. Residual analysis on the training and test sets
# Training-set residuals
residuals_train <- data.frame(
  Dataset = "Training",
  Fitted = y_pred_train,
  Residuals = train_data$指标1 - y_pred_train
)

# Test-set residuals
residuals_test <- data.frame(
  Dataset = "Testing"
  Fitted = y_pred_test,
  Residuals = test_data$指标1 - y_pred_test
)

residuals_combined <- rbind(residuals_train, residuals_test)

p_residual <- ggplot(residuals_combined, aes(x = Fitted, y = Residuals, color = Dataset)) +
  geom_point(alpha = 0.6) +
  geom_hline(yintercept = 0, linetype = "dashed", color = "red") +
  labs(title = "Residual Plot - Training vs Testing"
       x = "Fitted Values"
       y = "Residuals") +
  theme_minimal() +
  scale_color_manual(values = c("Training" = "blue", "Testing" = "green"))

# Q-Q plots for training and test residuals
p_qq_train <- ggplot(residuals_train, aes(sample = Residuals)) +
  stat_qq() + 
  stat_qq_line() +
  labs(title = "Q-Q Plot of Training Set Residuals") +
  theme_minimal()

p_qq_test <- ggplot(residuals_test, aes(sample = Residuals)) +
  stat_qq() + 
  stat_qq_line() +
  labs(title = "Q-Q Plot of Testing Set Residuals") +
  theme_minimal()

# Residual distributions for the training and test sets
p_resid_hist_train <- ggplot(residuals_train, aes(x = Residuals)) +
  geom_histogram(aes(y = after_stat(density)), bins = 30, fill = "lightblue", alpha = 0.7) +
  geom_density(color = "blue") +
  labs(title = "Distribution of Training Set Residuals"
       x = "Residuals"
       y = "Density") +
  theme_minimal()

p_resid_hist_test <- ggplot(residuals_test, aes(x = Residuals)) +
  geom_histogram(aes(y = after_stat(density)), bins = 30, fill = "lightgreen", alpha = 0.7) +
  geom_density(color = "darkgreen") +
  labs(title = "Distribution of Testing Set Residuals"
       x = "Residuals"
       y = "Density") +
  theme_minimal()

# 5. Complexity parameter plot
if(!is.null(tree_model$cptable)) {
  cp_data <- as.data.frame(tree_model$cptable)
  p_cp <- ggplot(cp_data, aes(x = nsplit + 1, y = xerror)) +
    geom_line(color = "steelblue", size = 1) +
    geom_point(color = "red", size = 2) +
    geom_errorbar(aes(ymin = xerror - xstd, ymax = xerror + xstd), 
                  width = 0.2, color = "darkgray") +
    labs(title = "Decision Tree Complexity Parameter Analysis",
         x = "Number of Splits",
         y = "Cross-Validation Error") +
    theme_minimal()
} else {
  p_cp <- ggplot() + 
    labs(title = "Complexity Parameter Analysis Not Available") +
    theme_minimal()
}
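
# Optional addition (not in the original post): rpart's built-in summaries of
# the same complexity-parameter table, using base graphics
printcp(tree_model)
plotcp(tree_model)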


# 6. Performance comparison plot
performance_comparison <- data.frame(
  Metric = rep(c("RMSE", "MAE", "R-squared"), 2),
  Value = c(rmse_train, mae_train, r_squared_train, 
            rmse_test, mae_test, r_squared_test),
  Dataset = rep(c("Training", "Testing"), each = 3)
)

p_performance <- ggplot(performance_comparison, aes(x = Metric, y = Value, fill = Dataset)) +
  geom_bar(stat = "identity", position = "dodge", alpha = 0.8) +
  labs(title = "Model Performance Comparison: Training vs Testing",
       x = "Performance Metrics",
       y = "Value") +
  theme_minimal() +
  scale_fill_manual(values = c("Training" = "steelblue", "Testing" = "orange")) +
  theme(legend.position = "bottom")

# 7. Learning curve (training-set size vs performance)
learning_curve_data <- data.frame()
train_sizes <- seq(0.1, 0.9, by = 0.1)

for (size in train_sizes) {
  set.seed(123)
  small_train_index <- createDataPartition(train_data$指标1, p = size, list = FALSE)
  small_train <- train_data[small_train_index, ]

# Fit a tree on the reduced training set
  small_tree <- rpart(指标1 ~ . - 序号, data = small_train, method = "anova")

# Training-set performance
  small_pred_train <- predict(small_tree, small_train)
  rmse_small_train <- sqrt(mean((small_train$指标1 - small_pred_train)^2))

# Test-set performance
  small_pred_test <- predict(small_tree, test_data)
  rmse_small_test <- sqrt(mean((test_data$指标1 - small_pred_test)^2))

  learning_curve_data <- rbind(learning_curve_data, 
                               data.frame(
                                 TrainSize = size * nrow(train_data),
                                 RMSE_Train = rmse_small_train,
                                 RMSE_Test = rmse_small_test
                               ))
}

p_learning_curve <- ggplot(learning_curve_data, aes(x = TrainSize)) +
  geom_line(aes(y = RMSE_Train, color = "Training"), size = 1) +
  geom_line(aes(y = RMSE_Test, color = "Testing"), size = 1) +
  labs(title = "Learning Curve: Training Size vs RMSE",
       x = "Training Set Size",
       y = "RMSE",
       color = "Dataset") +
  theme_minimal() +
  scale_color_manual(values = c("Training" = "blue", "Testing" = "red"))

# 8. Model stability: performance across repeated random splits
stability_results <- data.frame()
n_iterations <- 10

for (i in 1:n_iterations) {
  set.seed(100 + i)
  stability_index <- createDataPartition(data$指标1, p = 0.7, list = FALSE)
  stability_train <- data[stability_index, ]
  stability_test <- data[-stability_index, ]

  stability_tree <- rpart(指标1 ~ . - 序号, data = stability_train, method = "anova")
  stability_pred <- predict(stability_tree, stability_test)

  stability_rmse <- sqrt(mean((stability_test$指标1 - stability_pred)^2))
  stability_r2 <- 1 - sum((stability_test$指标1 - stability_pred)^2) / 
    sum((stability_test$指标1 - mean(stability_test$指标1))^2)

  stability_results <- rbind(stability_results,
                             data.frame(
                               Iteration = i,
                               RMSE = stability_rmse,
                               R_squared = stability_r2
                             ))
}

p_stability <- ggplot(stability_results, aes(x = Iteration, y = RMSE)) +
  geom_line(color = "steelblue", size = 1) +
  geom_point(color = "red", size = 2) +
  labs(title = "Model Stability Across Different Data Splits",
       x = "Iteration",
       y = "Test RMSE") +
  theme_minimal()

# Save all figures

# Decision tree structure plots
tryCatch({
  jpeg("Results-DecisionTree-ML/decision_tree_basic.jpg", width = 12, height = 10, units = "in", res = 300)
  p_tree_basic()
  dev.off()
}, error = function(e) {
  cat("Error saving decision_tree_basic.jpg:", e$message"\n")
})

tryCatch({
  jpeg("Results-DecisionTree-ML/decision_tree_fancy.jpg", width = 12, height = 10, units = "in", res = 300)
  p_tree_fancy()
  dev.off()
}, error = function(e) {
  cat("Error saving decision_tree_fancy.jpg:", e$message"\n")
})

# Save the ggplot figures with ggsave
ggsave("Results-DecisionTree-ML/variable_importance.jpg", p_var_importance, 
       width = 10, height = 8, units = "in", dpi = 300)
ggsave("Results-DecisionTree-ML/predicted_vs_actual.jpg", p_pred_vs_actual, 
       width = 10, height = 8, units = "in", dpi = 300)
ggsave("Results-DecisionTree-ML/residual_plot.jpg", p_residual, 
       width = 10, height = 8, units = "in", dpi = 300)
ggsave("Results-DecisionTree-ML/qq_plot_train.jpg", p_qq_train, 
       width = 10, height = 8, units = "in", dpi = 300)
ggsave("Results-DecisionTree-ML/qq_plot_test.jpg", p_qq_test, 
       width = 10, height = 8, units = "in", dpi = 300)
ggsave("Results-DecisionTree-ML/residual_distribution_train.jpg", p_resid_hist_train, 
       width = 10, height = 8, units = "in", dpi = 300)
ggsave("Results-DecisionTree-ML/residual_distribution_test.jpg", p_resid_hist_test, 
       width = 10, height = 8, units = "in", dpi = 300)
ggsave("Results-DecisionTree-ML/complexity_parameter.jpg", p_cp, 
       width = 10, height = 8, units = "in", dpi = 300)
ggsave("Results-DecisionTree-ML/performance_comparison.jpg", p_performance, 
       width = 10, height = 8, units = "in", dpi = 300)
ggsave("Results-DecisionTree-ML/learning_curve.jpg", p_learning_curve, 
       width = 10, height = 8, units = "in", dpi = 300)
ggsave("Results-DecisionTree-ML/model_stability.jpg", p_stability, 
       width = 10, height = 8, units = "in", dpi = 300)


# Generate a detailed machine-learning evaluation report (Word)
tryCatch({
  doc <- read_docx()

# Title
  doc <- doc %>% 
    body_add_par("Decision Tree Regression - Machine Learning Analysis Report", style = "heading 1") %>%
    body_add_par(paste("Date:", Sys.Date()), style = "Normal") %>%
    body_add_par("", style = "Normal")

# Data overview
  doc <- doc %>% 
    body_add_par("Data Overview", style = "heading 2") %>%
    body_add_par(paste("Total samples:", nrow(data)), style = "Normal") %>%
    body_add_par(paste("Training set samples:", nrow(train_data)), style = "Normal") %>%
    body_add_par(paste("Testing set samples:", nrow(test_data)), style = "Normal") %>%
    body_add_par(paste("Number of variables:", ncol(data)), style = "Normal") %>%
    body_add_par("", style = "Normal")

# Model performance summary
  doc <- doc %>% 
    body_add_par("Model Performance Summary", style = "heading 2")

  perf_table <- model_performance %>%
    flextable() %>%
    theme_zebra() %>%
    autofit()

  doc <- doc %>% 
    body_add_flextable(perf_table) %>%
    body_add_par("", style = "Normal")

# Overfitting analysis
  overfitting_gap <- rmse_test - rmse_train
  doc <- doc %>% 
    body_add_par("Overfitting Analysis", style = "heading 2") %>%
    body_add_par(paste("RMSE gap (Test - Train):", round(overfitting_gap, 4)), style = "Normal") %>%
    body_add_par(ifelse(overfitting_gap > 0.1, 
                        "Warning: Potential overfitting detected (large gap between train and test performance)",
                        "Good: Model shows good generalization (small gap between train and test performance)"), 
                 style = "Normal") %>%
    body_add_par("", style = "Normal")

# Variable importance
  doc <- doc %>% 
    body_add_par("Variable Importance", style = "heading 2")

  importance_ft <- var_importance %>%
    flextable() %>%
    theme_zebra() %>%
    autofit()

  doc <- doc %>% 
    body_add_flextable(importance_ft) %>%
    body_add_par("", style = "Normal")

# Stability analysis
  stability_summary <- data.frame(
    Metric = c("Mean Test RMSE", "Std Test RMSE", "Mean Test R-squared", "Std Test R-squared"),
    Value = c(mean(stability_results$RMSE), sd(stability_results$RMSE),
              mean(stability_results$R_squared), sd(stability_results$R_squared))
  )

  doc <- doc %>% 
    body_add_par("Model Stability Analysis", style = "heading 2")

  stability_ft <- stability_summary %>%
    flextable() %>%
    theme_zebra() %>%
    autofit()

  doc <- doc %>% 
    body_add_flextable(stability_ft) %>%
    body_add_par("", style = "Normal")

# Add figures to the report
  doc <- doc %>% 
    body_add_par("Visualization Results", style = "heading 2") %>%
    body_add_par("Decision tree structure:", style = "Normal") %>%
    body_add_img("Results-DecisionTree-ML/decision_tree_basic.jpg", width = 7, height = 6) %>%
    body_add_par("Performance comparison:", style = "Normal") %>%
    body_add_img("Results-DecisionTree-ML/performance_comparison.jpg", width = 6, height = 5) %>%
    body_add_par("Learning curve:", style = "Normal") %>%
    body_add_img("Results-DecisionTree-ML/learning_curve.jpg", width = 6, height = 5) %>%
    body_add_par("Model stability:", style = "Normal") %>%
    body_add_img("Results-DecisionTree-ML/model_stability.jpg", width = 6, height = 5) %>%
    body_add_par("Predicted vs actual values:", style = "Normal") %>%
    body_add_img("Results-DecisionTree-ML/predicted_vs_actual.jpg", width = 6, height = 5)

# Conclusions and recommendations
  doc <- doc %>% 
    body_add_par("Conclusion and Recommendations", style = "heading 2") %>%
    body_add_par("Based on machine learning analysis of decision tree regression:", style = "Normal") %>%
    body_add_par(paste("- Training R-squared:", round(r_squared_train * 100, 2), "%"), style = "Normal") %>%
    body_add_par(paste("- Testing R-squared:", round(r_squared_test * 100, 2), "%"), style = "Normal") %>%
    body_add_par(paste("- Model generalization gap:", round(overfitting_gap, 4)), style = "Normal") %>%
    body_add_par(paste("- Most important variables:", paste(head(var_importance$Variable, 3), collapse = ", ")), style = "Normal") %>%
    body_add_par("", style = "Normal") %>%
    body_add_par("Recommendations:", style = "Normal") %>%
    body_add_par("1. Consider hyperparameter tuning if overfitting is detected", style = "Normal") %>%
    body_add_par("2. Explore ensemble methods (Random Forest, Gradient Boosting) for improved performance", style = "Normal") %>%
    body_add_par("3. Monitor model performance on new data to detect concept drift", style = "Normal")

# Save the Word document
print(doc, target = "Results-DecisionTree-ML/DecisionTree_ML_Analysis_Report.docx")
}, error = function(e) {
  cat("Error generating Word report:", e$message"\n")
})

# Save the R workspace
save.image("Results-DecisionTree-ML/DecisionTree_ML_Analysis.RData")

# Save the model objects
saveRDS(tree_model, "Results-DecisionTree-ML/decision_tree_model.rds")
saveRDS(cv_tree, "Results-DecisionTree-ML/cross_validated_tree.rds")

# Print completion messages
cat("Decision tree machine learning analysis completed!\n")
cat("Results saved to Results-DecisionTree-ML folder:\n")
cat("- Model performance comparison (training vs testing)\n"
cat("- Comprehensive visualizations including learning curves\n")
cat("- Model stability analysis\n")
cat("- Detailed Word analysis report\n")
cat("- R workspace and model objects\n")



 3. Decision Tree Regression (recap)

 Concept: a nonparametric regression method built on a tree structure.

 Principle

- Splitting criterion: minimize the variance within the child nodes (see the formula below)

- Stopping conditions: maximum depth, minimum node size, minimum impurity improvement

- Prediction: the mean of the samples in a leaf node

 Idea: recursively partition the feature space into rectangular regions and predict with a constant in each region.
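
In symbols (a standard formulation, stated here as a sketch rather than taken from the original text): a split of node $t$ into children $t_L$ and $t_R$ is worthwhile when the variance reduction

$$\Delta(t) = \mathrm{Var}(t) - \frac{n_{t_L}}{n_t}\,\mathrm{Var}(t_L) - \frac{n_{t_R}}{n_t}\,\mathrm{Var}(t_R)$$

is large enough; in rpart, a split is kept only if it improves the overall fit by at least the complexity parameter cp.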




医学统计数据分析 shares notes on SPSS, R, Python, ArcGIS, GeoDa, GraphPad, and data-analysis chart making, and takes on commissioned work in data analysis, manuscript revision, medical statistics, machine learning, survival analysis, spatial analysis, and questionnaire analysis. For submissions or commissioned analyses, contact me directly. Thank you!



!!! Join my fan groups !!!

In the bottom-right corner of the 医学统计数据分析 official account, find "联系作者" (Contact the Author) to add me on WeChat and get invited into a fan group.

【医学统计数据分析】studio fan groups

01

【Clinical】fan group

For colleagues with clinical epidemiology analysis needs, such as

t-tests, ANOVA, chi-square tests, and logistic regression;

repeated-measures ANOVA, paired t-tests, and ROC curves;

nonparametric tests, survival analysis, and sample-size estimation;

screening tests (sensitivity, specificity, Youden index, etc.);

plotting (bar, scatter, violin, nomogram, etc.);

and machine learning and deep learning —

please join the 【Clinical】fan group.

02

【Public Health】fan group

Colleagues in CDC and other public-health posts are welcome in the 【Public Health】fan group, which shares ecological studies, spatial analysis, time series, surveillance data analysis, spatiotemporal panel techniques, and other ways to automate routine and research work.

03

【Bioinformatics】fan group

Colleagues with laboratory data-analysis needs can join the 【Bioinformatics】fan group to discuss public databases such as NCBI (gene sequences), UniProt (proteins), KEGG (pathways), and GEO (public datasets), along with genomics, transcriptomics, proteomics, metabolomics, and phenomics analysis and visualization.



Or scan the QR code to add me on WeChat and join a group directly!





Premium video courses: paid collections on the 医学统计数据分析 video channel

After redeeming a course from the paid collections on the 医学统计数据分析 video channel, you can receive the lecture slides (PPT), code, and example data: find "联系作者" in the bottom-right corner of the 【医学统计数据分析】official account, add me on WeChat, and I will send everything in one package. Thank you for your support!


