We are running our own TGI container and trying to boot Mistral Instruct. It's dying while trying to use Flash Attention 2. I know this is because I'm on a T4 GPU, but for the life of me I can't figure out how to tell TGI not to use Flash Attention 2.
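For context, we launch the container with something like the following (the image tag, volume path, and model revision here are illustrative, not our exact values):

```sh
# Roughly how we start TGI on our own T4 instance (values are illustrative).
docker run --gpus all --shm-size 1g -p 8080:80 \
  -v $PWD/data:/data \
  ghcr.io/huggingface/text-generation-inference:2.2.0 \
  --model-id mistralai/Mistral-7B-Instruct-v0.2
```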
Somehow, when we deploy it through HuggingFace on an AWS T4, it knows not to use it. Looking at the logs for the HF deployment, I see:
- 2024-08-01T01:48:41.507+00:00 {"timestamp":"2024-08-01T01:48:41.507739Z","level":"INFO","fields":{"message":"Successfully downloaded weights."},"target":"text_generation_launcher","span":{"name":"download"},"spans":[{"name":"download"}]}
- 2024-08-01T01:48:41.507+00:00 {"timestamp":"2024-08-01T01:48:41.507904Z","level":"INFO","fields":{"message":"Starting shard"},"target":"text_generation_launcher","span":{"rank":0,"name":"shard-manager"},"spans":[{"rank":0,"name":"shard-manager"}]}
- 2024-08-01T01:48:41.507+00:00 {"timestamp":"2024-08-01T01:48:41.507914Z","level":"INFO","fields":{"message":"Starting shard"},"target":"text_generation_launcher","span":{"rank":1,"name":"shard-manager"},"spans":[{"rank":1,"name":"shard-manager"}]}
- 2024-08-01T01:48:41.507+00:00 {"timestamp":"2024-08-01T01:48:41.507952Z","level":"INFO","fields":{"message":"Starting shard"},"target":"text_generation_launcher","span":{"rank":2,"name":"shard-manager"},"spans":[{"rank":2,"name":"shard-manager"}]}
- 2024-08-01T01:48:41.508+00:00 {"timestamp":"2024-08-01T01:48:41.507977Z","level":"INFO","fields":{"message":"Starting shard"},"target":"text_generation_launcher","span":{"rank":3,"name":"shard-manager"},"spans":[{"rank":3,"name":"shard-manager"}]}
- 2024-08-01T01:48:46.187+00:00 {"timestamp":"2024-08-01T01:48:46.187088Z","level":"WARN","fields":{"message":"Unable to use Flash Attention V2: GPU with CUDA capability 7 5 is not supported for Flash Attention V2\n"},"target":"text_generation_launcher"}
- 2024-08-01T01:48:46.204+00:00 {"timestamp":"2024-08-01T01:48:46.204816Z","level":"WARN","fields":{"message":"Unable to use Flash Attention V2: GPU with CUDA capability 7 5 is not supported for Flash Attention V2\n"},"target":"text_generation_launcher"}
- 2024-08-01T01:48:46.210+00:00 {"timestamp":"2024-08-01T01:48:46.210281Z","level":"WARN","fields":{"message":"Unable to use Flash Attention V2: GPU with CUDA capability 7 5 is not supported for Flash Attention V2\n"},"target":"text_generation_launcher"}
- 2024-08-01T01:48:46.210+00:00 {"timestamp":"2024-08-01T01:48:46.210527Z","level":"WARN","fields":{"message":"Unable to use Flash Attention V2: GPU with CUDA capability 7 5 is not supported for Flash Attention V2\n"},"target":"text_generation_launcher"}
- 2024-08-01T01:48:51.516+00:00 {"timestamp":"2024-08-01T01:48:51.516157Z","level":"INFO","fields":{"message":"Waiting for shard to be ready..."},"target":"text_generation_launcher","span":{"rank":3,"name":"shard-manager"},"spans":[{"rank":3,"name":"shard-manager"}]}
- 2024-08-01T01:48:51.516+00:00 {"timestamp":"2024-08-01T01:48:51.516157Z","level":"INFO","fields":{"message":"Waiting for shard to be ready..."},"target":"text_generation_launcher","span":{"rank":2,"name":"shard-manager"},"spans":[{"rank":2,"name":"shard-manager"}]}
- 2024-08-01T01:48:51.516+00:00 {"timestamp":"2024-08-01T01:48:51.516157Z","level":"INFO","fields":{"message":"Waiting for shard to be ready..."},"target":"text_generation_launcher","span":{"rank":0,"name":"shard-manager"},"spans":[{"rank":0,"name":"shard-manager"}]}
- 2024-08-01T01:48:51.516+00:00 {"timestamp":"2024-08-01T01:48:51.516157Z","level":"INFO","fields":{"message":"Waiting for shard to be ready..."},"target":"text_generation_launcher","span":{"rank":1,"name":"shard-manager"},"spans":[{"rank":1,"name":"shard-manager"}]}
Notice it correctly says it can’t use Flash Attention V2.
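That matches what the hardware reports: the T4 is CUDA compute capability 7.5, and FlashAttention 2 requires Ampere (8.0) or newer. A quick sanity check we can run from inside the container (this is just PyTorch, not part of TGI itself):

```sh
# Ask PyTorch for the device's compute capability; a T4 prints (7, 5),
# which is below the Ampere (8, 0) floor that FlashAttention 2 requires.
python -c "import torch; print(torch.cuda.get_device_capability())"
```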
When we run this on our own T4 in our AWS account, we don't see that warning. Instead, it bombs out trying to use Flash Attention v2 with:
Traceback (most recent call last):
  File "/opt/conda/bin/text-generation-server", line 8, in <module>
    sys.exit(app())
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 311, in __call__
    return get_command(self)(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1157, in __call__
    return self.main(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 778, in main
    return _main(
  File "/opt/conda/lib/python3.10/site-packages/typer/core.py", line 216, in _main
    rv = self.invoke(ctx)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1688, in invoke
    return _process_result(sub_ctx.command.invoke(sub_ctx))
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 1434, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/opt/conda/lib/python3.10/site-packages/click/core.py", line 783, in invoke
    return __callback(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/typer/main.py", line 683, in wrapper
    return callback(**use_params) # type: ignore
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/cli.py", line 109, in serve
    server.serve(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 274, in serve
    asyncio.run(
  File "/opt/conda/lib/python3.10/asyncio/runners.py", line 44, in run
    return loop.run_until_complete(main)
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 636, in run_until_complete
    self.run_forever()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/opt/conda/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/opt/conda/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(self._callback, *self._args)
  File "/opt/conda/lib/python3.10/site-packages/grpc_interceptor/server.py", line 165, in invoke_intercept_method
    return await self.intercept(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/interceptor.py", line 21, in intercept
    return await response
  File "/opt/conda/lib/python3.10/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 120, in _unary_interceptor
    raise error
  File "/opt/conda/lib/python3.10/site-packages/opentelemetry/instrumentation/grpc/_aio_server.py", line 111, in _unary_interceptor
    return await behavior(request_or_iterator, context)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/server.py", line 123, in Warmup
    max_supported_total_tokens = self.model.warmup(batch)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/flash_causal_lm.py", line 1099, in warmup
    _, batch, _ = self.generate_token(batch)
  File "/opt/conda/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/flash_causal_lm.py", line 1374, in generate_token
    out, speculative_logits = self.forward(batch, adapter_data)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/flash_causal_lm.py", line 1299, in forward
    logits, speculative_logits = self.model.forward(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_mistral_modeling.py", line 516, in forward
    hidden_states = self.model(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_mistral_modeling.py", line 440, in forward
    hidden_states, residual = layer(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_mistral_modeling.py", line 365, in forward
    attn_output = self.self_attn(
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/models/custom_modeling/flash_mistral_modeling.py", line 218, in forward
    attn_output = attention(
  File "/opt/conda/lib/python3.10/site-packages/text_generation_server/layers/attention/cuda.py", line 221, in attention
    return flash_attn_2_cuda.varlen_fwd(
RuntimeError: FlashAttention only supports Ampere GPUs or newer.
Does anyone know how I'm supposed to tell TGI not to use Flash Attention v2, or how the HF TGI deployment gets away with just a warning and keeps going?
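The closest thing I've found is passing an environment variable to the container; older TGI issues mention `USE_FLASH_ATTENTION=false`, but I haven't confirmed that current images still honor it, so treat this as a guess rather than a known fix:

```sh
# Unverified: try to disable flash attention via an env var on the container.
# USE_FLASH_ATTENTION shows up in older TGI issues; newer images may ignore it.
docker run --gpus all --shm-size 1g -p 8080:80 \
  -e USE_FLASH_ATTENTION=false \
  -v $PWD/data:/data \
  ghcr.io/huggingface/text-generation-inference:2.2.0 \
  --model-id mistralai/Mistral-7B-Instruct-v0.2
```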