Welcome to OStack Knowledge Sharing Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

r - Fast linear regression by group

I have 500K users and I need to fit a linear regression (with an intercept) for each of them.

Each user has around 30 records.

I tried dplyr with lm, but this is far too slow: around 2 seconds per user.

  df %>%
      group_by(user_id, add = FALSE) %>%
      do(lm = lm(Y ~ x, data = .)) %>%
      mutate(lm_b0 = summary(lm)$coeff[1],
             lm_b1 = summary(lm)$coeff[2]) %>%
      select(user_id, lm_b0, lm_b1) %>%
      ungroup()

I tried lm.fit, which is known to be faster, but it doesn't seem to be compatible with dplyr.

Is there a fast way to do a linear regression by group?


1 Answer

You can just use the basic formulas for calculating the slope and intercept. lm does a lot of unnecessary work if all you care about are those two numbers. Here I use data.table for the aggregation, but you could do it in base R as well (or dplyr):

system.time(
  res <- DT[, 
    {
      ux <- mean(x)
      uy <- mean(y)
      slope <- sum((x - ux) * (y - uy)) / sum((x - ux) ^ 2)
      list(slope=slope, intercept=uy - slope * ux)
    }, by=user.id
  ]
)

For 500K users with ~30 observations each, this produces the following timings (in seconds):

 user  system elapsed 
 7.35    0.00    7.36 

Or about 15 microseconds per user.
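
For anyone who wants the base R route mentioned above, here is a minimal sketch that gets the same two numbers from per-group sums via rowsum(). The data, sizes and names (user_id, sums, res_base) are made up for illustration and are not part of the benchmark above. It uses the algebraically equivalent "sum of products" form of the formulas, which can lose precision if x values are large relative to their spread, so treat it as a sketch rather than a drop-in replacement.

```r
# Small synthetic data set (hypothetical sizes, far smaller than 500K users)
set.seed(1)
n_users <- 100
n_rec <- 30
user_id <- rep(seq_len(n_users), each = n_rec)
x <- runif(n_users * n_rec)
y <- x + runif(n_users * n_rec) * 4 - 2

# Per-group sums in one vectorized pass; the constant 1 is recycled,
# so its per-group sum is the group size n
sums <- rowsum(cbind(n = 1, sx = x, sy = y, sxy = x * y, sxx = x^2), user_id)

n  <- sums[, "n"]
sx <- sums[, "sx"]
sy <- sums[, "sy"]

# Closed-form simple regression coefficients from the sums
slope     <- (sums[, "sxy"] - sx * sy / n) / (sums[, "sxx"] - sx^2 / n)
intercept <- sy / n - slope * sx / n

res_base <- data.frame(
  user_id   = as.integer(rownames(sums)),
  slope     = unname(slope),
  intercept = unname(intercept)
)
```

The point of the design is that no R-level function is called per group: everything is a handful of vectorized operations over the whole data set.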

Update: I ended up writing a bunch of blog posts that touch on this as well.

And to confirm this is working as expected:

> summary(DT[user.id==89663, lm(y ~ x)])$coefficients
             Estimate Std. Error   t value  Pr(>|t|)
(Intercept) 0.1965844  0.2927617 0.6714826 0.5065868
x           0.2021210  0.5429594 0.3722580 0.7120808
> res[user.id == 89663]
   user.id    slope intercept
1:   89663 0.202121 0.1965844
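
To check more than one user at a time, a sketch along these lines compares the closed-form numbers against lm.fit for every group. It runs on a small made-up data set in base R; closed_form, by_formula and by_lm_fit are illustrative names, not anything from the code above.

```r
# Small synthetic data set (hypothetical, 20 users with 30 records each)
set.seed(1)
n_users <- 20
n_rec <- 30
d <- data.frame(
  user_id = rep(seq_len(n_users), each = n_rec),
  x = runif(n_users * n_rec)
)
d$y <- d$x + runif(nrow(d)) * 4 - 2

# Same formulas as in the data.table call, wrapped in a helper
closed_form <- function(x, y) {
  ux <- mean(x)
  uy <- mean(y)
  slope <- sum((x - ux) * (y - uy)) / sum((x - ux)^2)
  c(intercept = uy - slope * ux, slope = slope)
}

groups <- split(d, d$user_id)

# One row per user: intercept, slope
by_formula <- t(sapply(groups, function(g) closed_form(g$x, g$y)))

# Reference fit: lm.fit with an explicit intercept column
by_lm_fit <- t(sapply(groups, function(g) {
  unname(lm.fit(cbind(1, g$x), g$y)$coefficients)
}))

# Largest absolute disagreement across all users and both coefficients
max_abs_diff <- max(abs(by_formula - by_lm_fit))
```

If max_abs_diff is on the order of floating-point noise, the two approaches agree for every group in the sample.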

Data:

library(data.table)

set.seed(1)
users <- 5e5
records <- 30
x <- runif(users * records)
DT <- data.table(
  x=x, y=x + runif(users * records) * 4 - 2, 
  user.id=sample(users, users * records, replace=TRUE)
)
