openpilot/tinygrad_repo/tinygrad/runtime/graph/remote.py
Vehicle Researcher c5d5c5d1f3 openpilot v0.10.1 release
date: 2025-10-24T00:30:59
master commit: 405631baf9685e171a0dd19547cb763f1b163d18
2025-10-24 00:31:03 -07:00

114 lines
6.2 KiB
Python

import time, itertools
from tinygrad.engine.jit import MultiGraphRunner
from tinygrad.engine.realize import CompiledRunner, BufferXfer, ExecItem
from tinygrad.device import Device, Compiled, Buffer
from tinygrad.runtime.ops_remote import RemoteDevice, RemoteConnection, RemoteRequest, GraphComputeItem, Transfer, GraphAlloc, GraphFree, GraphExec
from tinygrad.runtime.ops_remote import BatchTransfer, Event, Wait
from tinygrad.helpers import unwrap, flatten, dedup
from enum import Enum, auto
from dataclasses import replace
from collections import defaultdict
from typing import cast
class StagingType(Enum):
  """Kind of work RemoteGraph is currently accumulating before its next flush."""
  NONE = auto()      # nothing is being staged
  GRAPH = auto()     # items destined for a GraphAlloc/GraphExec
  TRANSFER = auto()  # transfers destined for a BatchTransfer
def rd(dev:Compiled) -> RemoteDevice:
  """Return `dev` typed as a RemoteDevice; `cast` only informs the type checker, nothing is converted."""
  remote = cast(RemoteDevice, dev)
  return remote
def dev_key(dev:RemoteDevice):
  """Return the key under which work for `dev` is grouped.

  Devices whose properties report `graph_supports_multi` are keyed by their connection;
  every other device is keyed by itself.
  """
  if dev.properties.graph_supports_multi:
    return dev.conn
  return dev
def map_rawbuf(rawbuf:Buffer):
  """Return the `(session, _buf)` pair for `rawbuf`: the session of the device it lives on plus its `_buf` value."""
  owner = cast(RemoteDevice, Device[rawbuf.device])
  return (owner.session, rawbuf._buf)
class RemoteGraph(MultiGraphRunner):
def __init__(self, jit_cache: list[ExecItem], rawbufs: list[Buffer], var_vals: dict[str, int]):
super().__init__(jit_cache, rawbufs, var_vals)
devices = dedup(flatten([[Device[unwrap(buf).device] for buf in ji.bufs] for ji in jit_cache]))
c2d = {device.conn: device for device in devices}
self.handle_indexes = {map_rawbuf(rawbufs[i]): i for i in sorted(dedup(self.input_replace.values()))}
self.template: list[RemoteRequest] = []
stagings: dict[RemoteDevice|RemoteConnection, list[GraphComputeItem|Transfer]] = defaultdict(list)
clobbered_buffers: set[Buffer] = set()
cur_staging_type: StagingType = StagingType.NONE
def _flush(new_staging_type:StagingType, force_break:bool=False):
nonlocal cur_staging_type
if cur_staging_type == new_staging_type and not force_break: return
# Pre-sync
if cur_staging_type == StagingType.TRANSFER:
for sdev,ddev in itertools.permutations(c2d.values(), 2):
self.template.append(Event(ddev.session, event:=next(ddev.event_num), session=sdev.session))
self.template.append(Wait(event, session=ddev.session))
# Flush
for dev in devices:
dk = dev_key(dev)
staging = stagings[dk]
if not staging: continue
match cur_staging_type:
case StagingType.GRAPH:
bufs = tuple(map_rawbuf(rawbufs[i]) for i in sorted(dedup(self.input_replace.values())) if dev_key(rd(Device[rawbufs[i].device])) == dk)
dev.q(GraphAlloc(graph_num:=next(dev.graph_num), tuple(staging), tuple(bufs), var_vals))
self.template.append(GraphExec(graph_num, bufs, var_vals, wait=False, session=dev.session))
case StagingType.TRANSFER:
st = cast(list[Transfer], staging)
for host in dedup(t.dsession.host for t in st):
sbuffer_nums = [(unwrap(t.session), t.buffer_num) for t in st if t.dsession.host == host]
dbuffer_nums = [(t.dsession, t.dbuffer_num) for t in st if t.dsession.host == host]
self.template.append(BatchTransfer(sbuffer_nums, dbuffer_nums, session=dev.session))
staging.clear()
# Post-sync
if cur_staging_type == StagingType.TRANSFER:
for sdev,ddev in itertools.permutations(c2d.values(), 2):
self.template.append(Event(ddev.session, event:=next(ddev.event_num), session=sdev.session))
self.template.append(Wait(event, session=ddev.session))
cur_staging_type = new_staging_type
clobbered_buffers.clear()
for ji in jit_cache:
match ji.prg:
case CompiledRunner():
_flush(StagingType.GRAPH)
gi = GraphComputeItem(ji.prg.dev.session, ji.prg._prg.name, ji.prg._prg.datahash, tuple(unwrap(buf)._buf for buf in ji.bufs),
tuple(ji.prg.p.vars), ji.fixedvars, tuple(ji.prg.p.ins), tuple(ji.prg.p.outs),
tuple(ji.prg.p.global_size) if ji.prg.p.global_size is not None else None,
tuple(ji.prg.p.local_size) if ji.prg.p.local_size is not None else None)
stagings[dev_key(ji.prg.dev)].append(gi)
case BufferXfer():
dest, src = ji.bufs[0:2]
dest_dev, src_dev = cast(RemoteDevice, Device[unwrap(dest).device]), cast(RemoteDevice, Device[unwrap(src).device])
assert dest is not None and src is not None, ji
ti = Transfer(session=src_dev.session, buffer_num=src._buf, dsession=dest_dev.session, dbuffer_num=dest._buf)
if dev_key(dest_dev) == dev_key(src_dev):
_flush(StagingType.GRAPH)
stagings[dev_key(src_dev)].append(ti)
elif dest_dev.conn == src_dev.conn:
_flush(StagingType.NONE)
self.template.append(ti)
else:
_flush(StagingType.TRANSFER, force_break=src in clobbered_buffers)
clobbered_buffers.add(dest)
stagings[dev_key(src_dev)].append(ti)
case _: raise NotImplementedError(ji.prg)
_flush(StagingType.NONE)
def __del__(self):
for req in self.template:
match req:
case GraphExec(): RemoteConnection(unwrap(req.session).host).q(GraphFree(req.graph_num, session=req.session))
def __call__(self, rawbufs: list[Buffer], var_vals: dict[str, int], wait=False):
if wait: st = time.perf_counter()
rmap = {orig: map_rawbuf(rawbufs[replace_idx]) for orig,replace_idx in self.handle_indexes.items()}
for req in self.template:
match req:
case GraphExec():
req = replace(req, bufs=tuple(rmap[buf] for buf in req.bufs), var_vals=var_vals, wait=wait)
case Transfer():
if (req.session, req.buffer_num) in rmap: req = replace(req, buffer_num=rmap[(req.session, req.buffer_num)][1])
if (req.dsession, req.dbuffer_num) in rmap: req = replace(req, dbuffer_num=rmap[(req.dsession, req.dbuffer_num)][1])
case BatchTransfer():
req = replace(req, sbuffer_nums=[rmap.get(b, b) for b in req.sbuffer_nums], dbuffer_nums=[rmap.get(b, b) for b in req.dbuffer_nums])
case Event()|Wait():
pass # event number can be reused
case _: raise NotImplementedError(req)
RemoteConnection(unwrap(req.session).host).q(req)
if wait:
RemoteConnection(unwrap(req.session).host).batch_submit()
return time.perf_counter() - st