feat(router): Dynamic batch sizing#210
Conversation
|
Nice! I need to think about your implementation and maybe play with it but I think it's a good idea. |
c89260f to
ba1aae3
Compare
|
Thanks @OlivierDehaene, I've now rebased it. |
|
@OlivierDehaene @Narsil continuing discussion from #246, I've pushed a new commit here to abstract the batch "weight" calculations to cover the non-flash attention case too. We use this for example with flan-t5-xxl where we set the max batch weight to I too am not that happy about how complex the changes are but I'm sure they can still be simplified/restructured a bit, and the queue-jumping logic could be removed or improved as discussed. |
|
If I understand correctly, currently we send the same batch to all workers, and then each worker run the same tokenization repeatedly, even though the model is sharded. I wonder if we could let the router split the batch and send different samples to different workers, so that we can avoid running tokenizer on same samples from different workers. To be honest, I don't understand why we don't split the batch like FSDP or DeepSpeed does. Would it be possible to further reduce memory usage from the split batches? |
|
Closing as stale. Thanks for the contribution Nick, happy to take some back now that we're back on Apache ! Cheers. |
Motivation
Currently to avoid OOM you must set a "worst case" max batch size based on the desired max sequence length. This means that (a) throughput is unnecessarily limited when there are many shorter sequences and (b) you have to be pretty conservative about the max context length offered.
These changes introduce a maximum batch "weight" parameter which in the flash attention case corresponds to a maximum total number of tokens in the batch. The idea is that this is roughly proportional to the memory requirement.
max_new_tokensvaluesIf
max_batch_weightis not set, it just infers this from themax_batch_sizeandmax_total_tokensargs. In this case it should behave roughly the same as it does now, so could hopefully be a "non breaking" change for existing configurationsIt turns out to be simpler to configure for a particular model/GPU. The precise values for
max_batch_sizeandmax_sequence_lengthno longer matter much, they can both be set quite high. You just need to determine one number (the max weight / total tokens), which is easy to do with minimal experimentation.We have been using this successfully for a while now and it means we can support a much higher throughput / volume of users with the same hardware while offering larger context lengths. For example, we have a deployment of GPT-NeoX 20B on one 80GB A100 with the max batch size set to 256 and the max sequence length (max_total_tokens) set to 8192. The actual batch size flexes automatically as needed. Our
max_batch_weightsetting for this is 10k.Details/caveats
I've only included the implementation for the flash attention case so far. The additions to generalize to the regular attention case aren't very big (we run non flash-attention models with this too), but I thought this was probably complicated enough to start with. It will need to support general case of course before actually being included.next_batchshould return immediately before getting into the more complex logic.next_batchfunction now takes the current entries map instead of a min and max, the tests inqueue.rswould need updating, so I just removed them for now.