class Wayformer(nn.Module):
def forward(
self,
target_valid: Tensor,
target_type: Tensor,
target_attr: Tensor,
other_valid: Tensor,
other_attr: Tensor,
tl_valid: Tensor,
tl_attr: Tensor,
map_valid: Tensor,
map_attr: Tensor,
inference_repeat_n: int = 1,
inference_cache_map: bool = False,
) -> Tuple[Tensor, Tensor, Tensor]:
"""
Args:
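target_type: [n_scene, n_target, 3]
inference_repeat_n: int, default 1. Only the type/default are visible from the signature; presumably an inference-time repetition count (TODO confirm against the method body).
inference_cache_map: bool, default False. Presumably toggles caching of map-related computation at inference time (TODO confirm against the method body).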
# target history, other history, map
if pl_aggr (configuration flag, not a forward() argument):
target_valid: [n_scene, n_target], bool
target_attr: [n_scene, n_target, agent_attr_dim]
other_valid: [n_scene, n_target, n_other], bool
other_attr: [n_scene, n_target, n_other, agent_attr_dim]
map_valid: [n_scene, n_target, n_map], bool
map_attr: [n_scene, n_target, n_map, map_attr_dim]
else:
target_valid: [n_scene, n_target, n_step_hist], bool
target_attr: [n_scene, n_target, n_step_hist, agent_attr_dim]
other_valid: [n_scene, n_target, n_other, n_step_hist], bool
other_attr: [n_scene, n_target, n_other, n_step_hist, agent_attr_dim]
map_valid: [n_scene, n_target, n_map, n_pl_node], bool
map_attr: [n_scene, n_target, n_map, n_pl_node, map_attr_dim]
# traffic lights: cannot be aggregated, detections are not tracked.
if use_current_tl (configuration flag, not a forward() argument):
tl_valid: [n_scene, n_target, 1, n_tl], bool
tl_attr: [n_scene, n_target, 1, n_tl, tl_attr_dim]
else:
tl_valid: [n_scene, n_target, n_step_hist, n_tl], bool
tl_attr: [n_scene, n_target, n_step_hist, n_tl, tl_attr_dim]
Returns: tuple (valid, conf, pred); the predictions are compared against "output/gt_pos": [n_scene, n_agent, n_step_future, 2]
valid: [n_scene, n_target]
conf: [n_decoder, n_scene, n_target, n_pred], not normalized!
pred: [n_decoder, n_scene, n_target, n_pred, n_step_future, pred_dim]
"""