43 lines
1.5 KiB (Stored with Git LFS)
Python
43 lines
1.5 KiB (Stored with Git LFS)
Python
"""Repository rule for remote GPU autoconfiguration.
|
|
|
|
This rule creates the starlark file
|
|
//third_party/toolchains/remote:execution.bzl
|
|
providing the function `gpu_test_tags`.
|
|
|
|
`gpu_test_tags` will return:
|
|
|
|
* `local`: if `REMOTE_GPU_TESTING` is false, allowing CPU tests to run
|
|
remotely and GPU tests to run locally in the same bazel invocation.
|
|
* `remote-gpu`: if `REMOTE_GPU_TESTING` is true; this allows rules to
|
|
set an execution requirement that enables a GPU-enabled remote platform.
|
|
"""
|
|
|
|
_REMOTE_GPU_TESTING = "REMOTE_GPU_TESTING"
|
|
|
|
def _flag_enabled(repository_ctx, flag_name):
|
|
if flag_name not in repository_ctx.os.environ:
|
|
return False
|
|
return repository_ctx.os.environ[flag_name].strip() == "1"
|
|
|
|
def _remote_execution_configure(repository_ctx):
|
|
# If we do not support remote gpu test execution, mark them as local, so we
|
|
# can combine remote builds with local gpu tests.
|
|
gpu_test_tags = "\"local\""
|
|
if _flag_enabled(repository_ctx, _REMOTE_GPU_TESTING):
|
|
gpu_test_tags = "\"remote-gpu\""
|
|
repository_ctx.template(
|
|
"remote_execution.bzl",
|
|
Label("//third_party/toolchains/remote:execution.bzl.tpl"),
|
|
{
|
|
"%{gpu_test_tags}": gpu_test_tags,
|
|
},
|
|
)
|
|
repository_ctx.template(
|
|
"BUILD",
|
|
Label("//third_party/toolchains/remote:BUILD.tpl"),
|
|
)
|
|
|
|
remote_execution_configure = repository_rule(
|
|
implementation = _remote_execution_configure,
|
|
environ = [_REMOTE_GPU_TESTING],
|
|
)
|