BAWTO

Tree Partitioning Survival Models

This notebook relies on `clinDF` and `biomDF`, which are computed in the analysis notebook; those objects are loaded here. Everything else is computed in this notebook.

library(tidyverse) 
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.1 ──
## ✔ ggplot2 3.3.5     ✔ purrr   0.3.4
## ✔ tibble  3.1.6     ✔ dplyr   1.0.8
## ✔ tidyr   1.2.0     ✔ stringr 1.4.0
## ✔ readr   2.1.2     ✔ forcats 0.5.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag()    masks stats::lag()
library(reshape2) # melt (should replace with pivot_longer)
## 
## Attaching package: 'reshape2'
## The following object is masked from 'package:tidyr':
## 
##     smiths
library(partykit) # model based partitioning
## Loading required package: grid
## Loading required package: libcoin
## Loading required package: mvtnorm
library(survival) # survival analysis
library(ggparty)  # plot model based partitioning
library(ggkm)     # plot kaplan-meier curves devtools::install_github("sachsmc/ggkm")
## Loading required package: scales
## 
## Attaching package: 'scales'
## The following object is masked from 'package:purrr':
## 
##     discard
## The following object is masked from 'package:readr':
## 
##     col_factor
library(patchwork)

# Download the two serialized inputs of this notebook over HTTP (network access
# is required). Further down, clinDF is iterated by name and each element is
# treated as a data frame tagged with that name as `trial`, and biomDF is
# row-bound into one long table -- so both are presumably lists of per-trial
# data frames (verify against the analysis notebook that writes them).
clinDF <- readRDS(url("http://insilica.co.s3.amazonaws.com/bawto/resources/clindf.rds"))
biomDF <- readRDS(url("http://insilica.co.s3.amazonaws.com/bawto/resources/biomdf.rds"))

# Format the hazard ratio of a fitted model as the string "HR lower-upper".
#
# x         : a fitted model whose summary() exposes a `coefficients` matrix
#             with an "exp(coef)" column and a `conf.int` matrix with
#             "lower .95" / "upper .95" columns (e.g. a survival::coxph fit).
# digitsval : number of significant digits kept for every number.
#
# Returns a character vector with one element per model coefficient. Leading
# zeros are dropped for compactness ("0.52" is printed as ".52").
getHazardRatio <- function(x, digitsval = 5) {
  fit.summary <- summary(x)

  # Address the columns by their full names instead of `$coef[2]`, which relied
  # on partial matching and on linear indexing into the coefficient matrix (and
  # therefore picked the wrong cell as soon as the model had >1 covariate).
  HR               <- signif(fit.summary$coefficients[, "exp(coef)"], digits = digitsval)
  HR.confint.lower <- signif(fit.summary$conf.int[, "lower .95"], digitsval)
  HR.confint.upper <- signif(fit.summary$conf.int[, "upper .95"], digitsval)

  label <- paste0(HR, " ", HR.confint.lower, "-", HR.confint.upper, "")

  # Strip only a *leading* zero: the look-behind keeps "10.5" or "20.3" intact,
  # whereas a plain "0\\." pattern would have turned them into "1.5" / "2.3".
  gsub("(?<![0-9])0\\.", ".", label, perl = TRUE)
}

# Numeric counterpart of getHazardRatio(): extract the hazard ratio of the
# FIRST coefficient of a fitted model together with its 95% interval.
#
# x : a fitted model whose summary() exposes a `coefficients` matrix with an
#     "exp(coef)" column and a `conf.int` matrix with "lower .95" /
#     "upper .95" columns (e.g. a survival::coxph fit).
#
# Returns a list with
#   HR    - exp(beta), rounded to 5 significant digits
#   logHR - log(HR)
#   lower - lower 95% bound of the HR (5 significant digits)
#   upper - upper 95% bound of the HR (5 significant digits)
#   se    - standard error of log(HR), back-calculated from the width of the
#           (rounded) interval on the log scale using 1.96 as normal quantile
getHazardRatioNum <- function(x) {
  fit.summary <- summary(x)

  # Exact element/column names: no `$coef` partial matching, no magic column 2.
  HR    <- signif(fit.summary$coefficients[1, "exp(coef)"], digits = 5)
  lower <- signif(fit.summary$conf.int[1, "lower .95"], 5)
  upper <- signif(fit.summary$conf.int[1, "upper .95"], 5)
  SE    <- (log(upper) - log(lower)) / (2 * 1.96)

  list(HR = HR, logHR = log(HR), lower = lower, upper = upper, se = SE)
}
# Clinical table: tag every per-trial table with its list name as `trial`,
# stack them, and keep only the columns used downstream (pfs_months -> pfs).
cDF <- names(clinDF) |>
  lapply(function(name) mutate(clinDF[[name]], trial = name)) |>
  dplyr::bind_rows() |>
  select(trial, patient_id, pfs = pfs_months, censored, treatment)

# Long biomarker table joined to the clinical data.
#  * value is collapsed to "normal" (missing, or one of the listed wild-type
#    codes) versus "altered" (anything else)
#  * abiraterone is relabelled "AAP"
#  * treatment_class is "ARD" for AAP / enzalutamide and "PARPi" otherwise
bDF <- dplyr::bind_rows(biomDF) |>
  mutate(
    value = ifelse(
      is.na(value) | value %in% c("none", "normal", "FALSE", "WT"),
      "normal",
      "altered"
    )
  ) |>
  inner_join(cDF, by = "patient_id") |>
  mutate(
    variable = as.character(variable),
    treatment = ifelse(treatment == "abiraterone", "AAP", treatment),
    treatment_class = ifelse(treatment %in% c("AAP", "enzalutamide"), "ARD", "PARPi")
  )

# Wide modelling table: one row per patient, one factor column per biomarker,
# plus a 0/1 `progressed` event indicator derived from `censored`.
modelDF <- bDF |>
  reshape2::dcast(
    patient_id + pfs + censored + treatment + trial + treatment_class ~ variable,
    value.var = "value"
  ) |>
  mutate(progressed = ifelse(censored == "progressed", 1, 0)) |>
  mutate(across(-c(patient_id, pfs, censored, trial, progressed), as.factor)) |>
  select(-censored) |>
  rename(NKX3_1 = `NKX3-1`)

# Fitting function handed to mob() through `fit = cox` further down.
#
# y : response used as the left-hand side of the Cox formula (the mob formulas
#     in this notebook use Surv(pfs, progressed)).
# x : regressor matrix; its first column is dropped before fitting
#     (presumably the intercept column of the model matrix -- TODO confirm).
# start, weights, offset, ... : accepted but never forwarded to coxph().
#
# Returns the fitted coxph object for `y ~ x`.
cox <- function(y, x, start = NULL, weights = NULL, offset = NULL, ... ) {
  x <- x[, -1]
  coxph(formula = y ~ x)
}


# Plot callback for ggparty::geom_node_plot(): draws Kaplan-Meier curves of the
# node data, one curve per treatment_class, without legend or axis titles.
#
# data    : node data as supplied by ggparty; must contain the columns
#           'Surv(pfs, progressed).time', 'Surv(pfs, progressed).status' and
#           treatment_class.
# mapping : unused; kept so the function can be called like ggplot(data, mapping).
#
# Returns a ggplot object.
party_node <- function(data, mapping) {
  # Expose the two components of the Surv response under plain column names so
  # they can be mapped to the time/status aesthetics of geom_km().
  data$pfs <- data[, 'Surv(pfs, progressed).time']
  data$progressed <- data[, 'Surv(pfs, progressed).status']

  # The former per-call Cox fit (HR) and row count (pats) were never used and
  # have been removed; the fit could also fail on data with a single arm.
  ggplot(data, aes(time = pfs, status = progressed, color = treatment_class)) +
    geom_km() +
    theme(legend.position = "none") +
    xlab("") +
    ylab("")
}

# Draw a fitted model-based partitioning tree with ggparty.
#
# mobmod   : tree object accepted by ggparty(); every node's info_list must
#            carry `nobs` and a fitted model `object` that getHazardRatio()
#            understands.
# survname : unused (callers in this notebook omit it).
#
# Returns the ggparty plot: inner nodes are labelled with split variable, node
# size and HR; terminal nodes show a Kaplan-Meier panel (see party_node) plus
# an "n=.. HR=.." label. The plot's $data still carries the node info_list.
mob_partyfn <- function(mobmod, survname) {

  # Build the base plot once and reuse it for the labels and the final figure.
  tree <- ggparty(mobmod)

  # One label per node (row of the plot data), later shown on inner nodes only.
  splitLabs = apply(tree$data, 1, function(row) {
    il = row$info_list
    sprintf("%s-%s\nHR=%s", row$splitvar, il$nobs, il$object |> getHazardRatio(digitsval = 2))
  })

  # One label per node, later shown on terminal nodes only.
  HRS <- lapply(tree$data$info_list, function(il) {
    sprintf("n=%s HR=%s", il$nobs, il$object |> getHazardRatio(digitsval = 2))
  })

  tree +
    geom_edge() +
    geom_edge_label(size = 3) +
    geom_node_label(size = 3, aes(label = splitLabs), ids = "inner") +
    geom_node_plot(plot_call = "party_node", gglist = list(), ids = "terminal",
                   shared_axis_labels = TRUE, shared_legend = FALSE) +
    geom_node_label(size = 3, aes(label = HRS),
                    fontface = "bold",
                    ids = "terminal",
                    nudge_y = 0.01, nudge_x = 0.015)
}

3A Simple combined model

# Training set for the combined model: patients of the two trials
# "chinnaiyan" and "chi".
comb.train <- filter(modelDF, trial %in% c("chinnaiyan", "chi"))

# Biomarkers with enough data in the combined set. For every biomarker:
#   mincount - smallest number of measured patients over its
#              (value, treatment_class) cells
#   dist_val - number of distinct (value, treatment_class) cells
#   trials   - largest number of trials contributing to any cell
# Only biomarkers with mincount > 9, all 4 cells present and a cell covered by
# both trials are kept, largest mincount first.
comb.highcount <- comb.train |>
  select(-treatment) |>
  melt(id.vars = c("patient_id", "pfs", "treatment_class", "progressed", "trial")) |>
  mutate(variable = as.character(variable)) |>
  filter(!is.na(value)) |>
  group_by(variable, value, treatment_class) |>
  summarise(
    measured = n(),
    n_trials = n_distinct(trial),
    trials = paste(unique(trial), collapse = ",")
  ) |>
  group_by(variable) |>
  summarise(
    mincount = min(measured),
    dist_val = n_distinct(value, treatment_class),
    trials = max(n_trials)
  ) |>
  filter(mincount > 9 & dist_val == 4 & trials == 2) |>
  arrange(desc(mincount))
## Warning: attributes are not identical across measure variables; they will be
## dropped
## `summarise()` has grouped output by 'variable', 'value'. You can override using
## the `.groups` argument.
# Names of the retained biomarkers, followed by a printed per-biomarker
# overview of the smallest cell count.
comb.highcount.genes <- unique(comb.highcount[["variable"]])
summarise(group_by(comb.highcount, variable), mincount = min(mincount))
## # A tibble: 4 × 2
##   variable mincount
##   <chr>       <int>
## 1 AR             23
## 2 BRCA2          10
## 3 PTEN           21
## 4 TP53           18
# Model-based partitioning on the combined training set: a Cox model of PFS on
# treatment_class in every node, with the retained biomarkers as candidate
# split variables.
formula <- as.formula(paste0(
  "Surv(pfs,progressed) ~ treatment_class | ",
  paste(comb.highcount.genes, collapse = "+")
))
mb <- mob(
  formula = formula,
  data = comb.train,
  fit = cox,
  control = partykit::mob_control(
    bonferroni = FALSE,
    alpha = 1.0,
    minsize = 15,
    maxdepth = 5
  )
)
res <- mob_partyfn(mb)
fig3A <- res

# Printed overview: biomarkers whose smallest (value, treatment_class) cell in
# the full long table holds more than 10 rows; absent cells count as 0.
bDF |>
  group_by(variable, value, treatment_class) |>
  count() |>
  ungroup() |>
  complete(variable, value, treatment_class, fill = list(n = 0)) |>
  group_by(variable) |>
  summarise(min_patients = min(n)) |>
  filter(min_patients > 10)
## # A tibble: 7 × 2
##   variable    min_patients
##   <chr>              <int>
## 1 AR                    23
## 2 ATM                   17
## 3 BRCA1                 11
## 4 BRCA2                 27
## 5 ETS_Fusions           14
## 6 PTEN                  21
## 7 TP53                  18
# Modelling table with pathway-level features.
#  1. drop `treatment` and the pre-existing `any` column
#  2. turn every biomarker column into character and treat a missing
#     measurement as "normal"
#  3. add one "altered"/"normal" flag per pathway: a pathway is "altered" as
#     soon as any of its member genes is
mdf2.1 <- modelDF |>
  select(-treatment, -any) |>
  mutate(across(
    -c(patient_id, pfs, progressed, treatment_class, trial),
    \(column) {
      txt <- as.character(column)
      ifelse(is.na(txt), "normal", txt)
    }
  )) |>
  mutate(
    cell_cycle = ifelse(
      TP53 == "altered" | RB1 == "altered" | MDM2 == "altered" |
        CDKN2A == "altered" | MYC == "altered" | CCND1 == "altered",
      "altered", "normal"
    ),
    PI3K_pathway = ifelse(
      PTEN == "altered" | PIK3R1 == "altered" | PIK3CA == "altered" |
        PIK3CB == "altered" | AKT1 == "altered",
      "altered", "normal"
    ),
    WNT_pathway = ifelse(
      APC == "altered" | CTNNB1 == "altered" | RNF43 == "altered",
      "altered", "normal"
    ),
    DNA_repair = ifelse(
      BRCA2 == "altered" | BRCA1 == "altered" | ATM == "altered" |
        CDK12 == "altered" | MSH2 == "altered" | MSH6 == "altered" |
        MLH1 == "altered" | FANCA == "altered",
      "altered", "normal"
    ),
    AR_associated = ifelse(
      FOXA1 == "altered" | FOXP1 == "altered" | ZBTB16 == "altered" | AR == "altered",
      "altered", "normal"
    ),
    Chromatin_modifiers = ifelse(
      KMT2C == "altered" | KMT2D == "altered" | CHD1 == "altered" | KDM6A == "altered",
      "altered", "normal"
    )
  )

# Per-patient summary flag `any`: "altered" when the patient has more than two
# "altered" values across ALL melted columns of mdf2.1 -- i.e. the individual
# genes and the derived pathway flags alike -- and "normal" otherwise.
# NOTE: data.table::setDT() converts its argument by reference, so mdf2.1
# itself is a data.table from this point on.
anyDF <- mdf2.1 |> data.table::setDT() |> 
  melt(id.vars=c("patient_id","pfs","treatment_class","progressed","trial")) |> 
  group_by(patient_id) |> summarise(any = ifelse(sum(value == "altered",na.rm = T)>2,"altered","normal")) |> ungroup()

# Final modelling table: mdf2.1 plus the `any` flag, with every biomarker /
# pathway column (everything except the five id/outcome columns) as a factor.
mdf2 <- mdf2.1 |> inner_join(anyDF,by="patient_id")  |>
  mutate_at(vars(-patient_id,-pfs,-progressed,-treatment_class,-trial),.funs=as.factor) 

# Candidate split variables: for each variable take the smallest number of
# distinct patients over all (value, treatment_class) combinations -- missing
# combinations are filled in with 0 by complete() -- and keep the variables
# whose smallest cell holds more than 10 patients, smallest first.
# As above, setDT() turns mdf2 into a data.table in place.
highcount <- mdf2 |> 
  data.table::setDT() |> 
  melt(id.vars=c("patient_id","pfs","treatment_class","progressed","trial")) |> 
  group_by(variable,value,treatment_class) |> summarize(patients=n_distinct(patient_id)) |> ungroup() |>
  complete(variable,value,treatment_class,fill=list(patients=0)) |> 
  group_by(variable) |> summarize(min_patients = min(patients)) |> 
  arrange(min_patients) |> 
  filter(min_patients > 10)
## Warning: attributes are not identical across measure variables; they will be
## dropped
## `summarise()` has grouped output by 'variable', 'value'. You can override using
## the `.groups` argument.
# Split-variable names for the ensemble below: every retained variable except
# "DRD", "HRDc" and "DRD_del", which are explicitly excluded here.
highcount.genes <- setdiff(unique(highcount$variable),c("DRD","HRDc","DRD_del"))
# Print the selection for the record.
highcount.genes
##  [1] "BRCA1"         "ATM"           "TP53"          "ETS_Fusions"  
##  [5] "cell_cycle"    "PTEN"          "AR"            "PI3K_pathway" 
##  [9] "any"           "AR_associated" "BRCA2"         "DNA_repair"
# Ensemble of model-based partitioning trees on the full modelling table.
# Every tree fits a Cox model of PFS on treatment_class per node and may split
# on the variables in highcount.genes; mob_control(mtry = 5) restricts each
# tree to a random subset of split candidates, so the trees differ.
formula = as.formula(sprintf("Surv(pfs,progressed) ~ treatment_class | %s",paste(highcount.genes,collapse=" + ")))

# Number of trees; used for the progress bar, the result lists and the loop.
n.models <- 10

# The trees are random (mtry) and the risk cut-offs derived from them are
# reused further down, so fix the RNG state to make a re-run reproducible.
# NOTE(review): the seed is arbitrary and will not reproduce earlier unseeded
# runs.
set.seed(20220301)

# risks[[i]]  : per-patient hazard ratio of the terminal node the patient falls
#               into in tree i (rounded to 3 decimals)
# models[[i]] : the fitted tree i
risks <- vector("list", n.models)
models <- vector("list", n.models)

pb <- progress::progress_bar$new(total = n.models)
for (i in seq_len(n.models)) {
  pb$tick()
  mb <- mob(formula = formula, data = mdf2, fit = cox,
            control = mob_control(bonferroni = FALSE, alpha = 1.0, minsize = 5, mtry = 5))
  res <- mob_partyfn(mb)

  # Node id of every patient, and the HR of every node (indexed by node id).
  nod  <- predict.party(mb, mdf2, type = "node")
  HRDF <- lapply(res$data$info_list, function(il) {
    (il$object |> getHazardRatioNum())$HR
  }) |> unlist()

  models[[i]] <- mb
  risks[[i]] <- round(HRDF[nod], 3)
}
## Warning in coxph.fit(X, Y, istrat, offset, init, control, weights = weights, :
## Loglik converged before variable 1 ; coefficient may be infinite.
## Warning in coxph.fit(X, Y, istrat, offset, init, control, weights = weights, :
## Loglik converged before variable 1 ; coefficient may be infinite.

## Warning in coxph.fit(X, Y, istrat, offset, init, control, weights = weights, :
## Loglik converged before variable 1 ; coefficient may be infinite.

## Warning in coxph.fit(X, Y, istrat, offset, init, control, weights = weights, :
## Loglik converged before variable 1 ; coefficient may be infinite.

## Warning in coxph.fit(X, Y, istrat, offset, init, control, weights = weights, :
## Loglik converged before variable 1 ; coefficient may be infinite.
## Error in chol.default(meat) : 
##   the leading minor of order 1 is not positive definite
## Warning in coxph.fit(X, Y, istrat, offset, init, control, weights = weights, :
## Loglik converged before variable 1 ; coefficient may be infinite.

## Warning in coxph.fit(X, Y, istrat, offset, init, control, weights = weights, :
## Loglik converged before variable 1 ; coefficient may be infinite.
## Warning in coxph.fit(X, Y, istrat, offset, init, control, weights = weights, :
## Ran out of iterations and did not converge
## Warning in coxph.fit(X, Y, istrat, offset, init, control, weights = weights, :
## Loglik converged before variable 1 ; coefficient may be infinite.

## Warning in coxph.fit(X, Y, istrat, offset, init, control, weights = weights, :
## Loglik converged before variable 1 ; coefficient may be infinite.

## Warning in coxph.fit(X, Y, istrat, offset, init, control, weights = weights, :
## Loglik converged before variable 1 ; coefficient may be infinite.

## Warning in coxph.fit(X, Y, istrat, offset, init, control, weights = weights, :
## Loglik converged before variable 1 ; coefficient may be infinite.

## Warning in coxph.fit(X, Y, istrat, offset, init, control, weights = weights, :
## Loglik converged before variable 1 ; coefficient may be infinite.

## Warning in coxph.fit(X, Y, istrat, offset, init, control, weights = weights, :
## Loglik converged before variable 1 ; coefficient may be infinite.

## Warning in coxph.fit(X, Y, istrat, offset, init, control, weights = weights, :
## Loglik converged before variable 1 ; coefficient may be infinite.
## Error in chol.default(meat) : 
##   the leading minor of order 1 is not positive definite
## Warning in coxph.fit(X, Y, istrat, offset, init, control, weights = weights, :
## Loglik converged before variable 1 ; coefficient may be infinite.

## Warning in coxph.fit(X, Y, istrat, offset, init, control, weights = weights, :
## Loglik converged before variable 1 ; coefficient may be infinite.
## Error in chol.default(meat) : 
##   the leading minor of order 1 is not positive definite
## Error in chol.default(meat) : 
##   the leading minor of order 1 is not positive definite
## Warning in coxph.fit(X, Y, istrat, offset, init, control, weights = weights, :
## Loglik converged before variable 1 ; coefficient may be infinite.
## Error in chol.default(meat) : 
##   the leading minor of order 1 is not positive definite
## Warning in coxph.fit(X, Y, istrat, offset, init, control, weights = weights, :
## Loglik converged before variable 1 ; coefficient may be infinite.
## Error in chol.default(meat) : 
##   the leading minor of order 1 is not positive definite
## Warning in coxph.fit(X, Y, istrat, offset, init, control, weights = weights, :
## Loglik converged before variable 1 ; coefficient may be infinite.

## Warning in coxph.fit(X, Y, istrat, offset, init, control, weights = weights, :
## Loglik converged before variable 1 ; coefficient may be infinite.
## Error in chol.default(meat) : 
##   the leading minor of order 1 is not positive definite
## Error in chol.default(meat) : 
##   the leading minor of order 1 is not positive definite
## Warning in coxph.fit(X, Y, istrat, offset, init, control, weights = weights, :
## Loglik converged before variable 1 ; coefficient may be infinite.
# Patient x tree matrix of node hazard ratios; values above 100 are blanked.
m <- do.call(cbind, risks)
m[which(m > 100)] <- NA

# Per-patient risk score: mean node HR over the trees, ignoring HRs above 7,
# plus a tiny random jitter (sd 0.001).
result <- mdf2
result$risk <- local({
  hr <- do.call(cbind, risks)
  hr[which(hr > 7)] <- NA
  rowMeans(hr, na.rm = TRUE) + rnorm(n = nrow(mdf2)) / 1000
})

# Empirical distribution of the risk score: count, proportion and running
# totals per distinct risk value, in increasing order of risk.
r <- result |>
  group_by(risk) |>
  count() |>
  ungroup() |>
  mutate(pr = n / nrow(mdf2)) |>
  arrange(risk) |>
  mutate(cumn = cumsum(n), cump = cumsum(pr))

# Cut-offs: the first risk value whose cumulative proportion exceeds 27.5%
# (lorisk) and 66% (merisk), each nudged upwards by 1e-4.
lorisk <- arrange(filter(r, cump > 0.275), risk)$risk[1] + 0.0001
merisk <- arrange(filter(r, cump > 0.66), risk)$risk[1] + 0.0001
r
## # A tibble: 334 × 5
##     risk     n      pr  cumn    cump
##    <dbl> <int>   <dbl> <int>   <dbl>
##  1 0.168     1 0.00299     1 0.00299
##  2 0.258     1 0.00299     2 0.00599
##  3 0.301     1 0.00299     3 0.00898
##  4 0.301     1 0.00299     4 0.0120 
##  5 0.302     1 0.00299     5 0.0150 
##  6 0.313     1 0.00299     6 0.0180 
##  7 0.315     1 0.00299     7 0.0210 
##  8 0.358     1 0.00299     8 0.0240 
##  9 0.409     1 0.00299     9 0.0269 
## 10 0.436     1 0.00299    10 0.0299 
## # … with 324 more rows
# Print the lower risk cut-off computed above.
lorisk
## [1] 0.6351246
# cache the result so we can stop rebuilding.
# save.image("./cache.obj/mobforest.image")

# Long table behind the BRCA2-vs-model comparison figure.
# Patients are stratified twice:
#   BRCA2 - "PARPi strong preference" when BRCA2 is altered, else "no preference"
#   model - by risk score: < lorisk strong, [lorisk, merisk) weak,
#           >= merisk no preference
# Each (stratification, stratum) cell is annotated with the Cox hazard ratio of
# treatment_class within that patient subset and with its prevalence.
plotDF <- result |> 
  mutate(BRCA2 = ifelse(BRCA2 == "altered", "PARPi strong preference", "no preference")) |> 
  mutate(model = ifelse(risk < lorisk, "PARPi strong preference",
                        ifelse(risk < merisk, "PARPi weak preference", "no preference"))) |> 
  select(BRCA2, model, pfs, progressed, treatment_class, risk) |> 
  data.table::setDT() |> 
  melt(id.vars = c("pfs", "progressed", "treatment_class", "risk")) |> 
  mutate(value = factor(value, levels = c("no preference", "PARPi weak preference", "PARPi strong preference"))) |> 
  mutate(coxHR = case_when(
    variable == "BRCA2" & value == "PARPi strong preference" ~ 
      coxph(Surv(pfs, progressed) ~ treatment_class, result |> filter(BRCA2 == "altered")) |> getHazardRatio(digitsval = 2),
    variable == "BRCA2" & value == "no preference" ~ 
      coxph(Surv(pfs, progressed) ~ treatment_class, result |> filter(BRCA2 != "altered")) |> getHazardRatio(digitsval = 2),
    variable == "model" & value == "PARPi strong preference" ~ 
      coxph(Surv(pfs, progressed) ~ treatment_class, result |> filter(risk < lorisk)) |> getHazardRatio(digitsval = 2),
    # Weak stratum is [lorisk, merisk): use >= so the Cox subset is exactly the
    # set of patients labelled "PARPi weak preference" above (was `risk > lorisk`).
    variable == "model" & value == "PARPi weak preference" ~ 
      coxph(Surv(pfs, progressed) ~ treatment_class, result |> filter(risk < merisk, risk >= lorisk)) |> getHazardRatio(digitsval = 2),
    variable == "model" & value == "no preference" ~ 
      coxph(Surv(pfs, progressed) ~ treatment_class, result |> filter(risk >= merisk)) |> getHazardRatio(digitsval = 2)
  )) |> mutate(coxHR = sprintf("HR = %s", coxHR)) |> 
  # prevalence = share of all patients (rows of the risk matrix m) in the cell
  group_by(variable, value) |> mutate(prev = n() / nrow(m)) |> ungroup() |> 
  mutate(text = sprintf("%s\nprev=%.1f%%", coxHR, prev * 100))

# Kaplan-Meier curves per treatment class, facetted by stratification (rows)
# and stratum (columns), with the HR / prevalence text placed in each panel.
# The dummy pfs / progressed columns only satisfy the inherited time / status
# aesthetics of the text layer.
fig3D <- ggplot(plotDF, aes(time = pfs, status = progressed, color = treatment_class)) + 
  geom_km() + geom_kmticks() + geom_kmband() + 
  geom_text(data = plotDF |> select(variable, value, text) |> distinct() |>
              mutate(pfs = 1, progressed = 1, x = 20, y = 0.8), aes(x = x, y = y, label = text), color = "black") + 
  facet_grid(variable ~ value) + theme_bw() + ylab("") + xlab("")
fig3D

3B Chinnaiyan

# Single-trial training set: only "chinnaiyan" patients, with the individual
# treatment taking the place of treatment_class.
chinn.train <- modelDF |>
  filter(trial == "chinnaiyan") |>
  mutate(treatment_class = treatment)

# Per (biomarker, value, treatment) cell counts for that trial, annotated with
# the biomarker's smallest cell (mincount) and its number of distinct cells
# (dist_val). Only biomarkers with all 4 cells present and more than 10
# patients in the smallest one are kept, largest mincount first.
chinn.highcount <- chinn.train |>
  select(-c(treatment_class, trial)) |>
  melt(id.vars = c("patient_id", "pfs", "treatment", "progressed")) |>
  mutate(variable = as.character(variable)) |>
  filter(!is.na(value)) |>
  group_by(variable, value, treatment) |>
  summarise(measured = n()) |>
  group_by(variable) |>
  mutate(
    mincount = min(measured),
    dist_val = n_distinct(value, treatment)
  ) |>
  ungroup() |>
  filter(dist_val == 4 & mincount > 10) |>
  arrange(desc(mincount))
## Warning: attributes are not identical across measure variables; they will be
## dropped
## `summarise()` has grouped output by 'variable', 'value'. You can override using
## the `.groups` argument.
# Biomarkers retained for the single-trial tree; printed for the record.
chinn.genes <- unique(chinn.highcount[["variable"]])
chinn.genes
## [1] "TP53"        "PTEN"        "ETS_Fusions" "AR"
# Single-trial tree: Cox model of PFS on treatment_class (here the individual
# treatment) in every node, split on the retained biomarkers.
formula <- as.formula(paste0(
  "Surv(pfs,progressed) ~ treatment_class | ",
  paste(chinn.genes, collapse = "+")
))
mb2 <- mob(
  formula = formula,
  data = chinn.train,
  fit = cox,
  control = mob_control(bonferroni = FALSE, alpha = 1.0, minsize = 20)
)
fig3B <- mob_partyfn(mb2)

3D Building populations

Which populations benefit from PARPi despite not having a BRCA2 mutation? Which benefit despite not having an HRD mutation?

# Cross-tabulate patients by HRD status and by model preference.
#   HRD/DRD - "HRD" when any of the eight listed repair genes is altered
#   risk    - the numeric score is replaced by a label, split at 0.63
# NOTE(review): 0.63 looks like a hand-rounded copy of `lorisk` computed
# earlier in this notebook -- confirm whether `lorisk` should be used instead.
HRD.vs.model <- result |>
  mutate(
    `HRD/DRD` = ifelse(
      BRCA2 == "altered" | BRCA1 == "altered" | ATM == "altered" |
        CDK12 == "altered" | MSH2 == "altered" | MSH6 == "altered" |
        MLH1 == "altered" | FANCA == "altered",
      "HRD", "NO HRD"
    ),
    risk = ifelse(risk < 0.63, "model prefers PARP-i", "model no preference")
  )

# Hazard ratio of treatment_class within each of the four cells.
HR1 <- getHazardRatio(
  coxph(
    Surv(pfs, progressed) ~ treatment_class,
    filter(HRD.vs.model, `HRD/DRD` == "HRD", risk == "model prefers PARP-i")
  ),
  digitsval = 2
)

HR2 <- getHazardRatio(
  coxph(
    Surv(pfs, progressed) ~ treatment_class,
    filter(HRD.vs.model, `HRD/DRD` == "NO HRD", risk == "model prefers PARP-i")
  ),
  digitsval = 2
)

HR3 <- getHazardRatio(
  coxph(
    Surv(pfs, progressed) ~ treatment_class,
    filter(HRD.vs.model, `HRD/DRD` == "HRD", risk == "model no preference")
  ),
  digitsval = 2
)

HR4 <- getHazardRatio(
  coxph(
    Surv(pfs, progressed) ~ treatment_class,
    filter(HRD.vs.model, `HRD/DRD` == "NO HRD", risk == "model no preference")
  ),
  digitsval = 2
)

# Attach the matching "HR = ..." label to every patient row.
HRD.vs.model.HR <- HRD.vs.model |>
  mutate(
    HR = sprintf(
      "HR = %s",
      case_when(
        `HRD/DRD` == "HRD"    & risk == "model prefers PARP-i" ~ HR1,
        `HRD/DRD` == "NO HRD" & risk == "model prefers PARP-i" ~ HR2,
        `HRD/DRD` == "HRD"    & risk == "model no preference"  ~ HR3,
        `HRD/DRD` == "NO HRD" & risk == "model no preference"  ~ HR4
      )
    )
  )

# Kaplan-Meier curves per treatment class, facetted HRD status x model
# preference, with the cell's HR printed in each panel. The dummy pfs /
# progressed columns only satisfy the inherited time / status aesthetics.
fig3C <- ggplot(HRD.vs.model, aes(time = pfs, status = progressed, col = treatment_class)) +
  geom_km() +
  geom_kmticks() +
  geom_kmband() +
  geom_text(
    data = HRD.vs.model.HR |>
      select(`HRD/DRD`, risk, HR) |>
      distinct() |>
      mutate(pfs = 1, progressed = 1, x = 20, y = 0.8),
    aes(x = x, y = y, label = HR),
    color = "black"
  ) +
  facet_grid(`HRD/DRD` ~ risk) +
  theme_bw() +
  xlab("") +
  ylab("") +
  theme(legend.position = "none")
fig3C

# Wrap the plots as clipped patchwork elements. Only the two tree plots are
# used in wrapped form below; fig3C.clip is created but the layout adds the
# unwrapped fig3C.
fig3A.clip <- wrap_elements(full = fig3A, clip = TRUE)
fig3B.clip <- wrap_elements(full = fig3B, clip = TRUE)
fig3C.clip <- wrap_elements(full = fig3C, clip = TRUE)

# 2 x 2 composite. Panels are tagged A-D in the order they are added, so fig3D
# receives tag "C" and fig3C receives tag "D". The trailing `&` applies the
# legend theme to every panel.
(fig3A.clip + fig3B.clip + fig3D + fig3C) +
  plot_layout(
    nrow = 2,
    ncol = 2,
    heights = c(2, 1),
    widths = c(3, 2),
    guides = "collect"
  ) +
  plot_annotation(tag_levels = list(c("A", "B", "C", "D"))) &
  theme(legend.title = element_blank(), legend.position = "bottom")