Skip to content

Commit 82f1ace

Browse files
committed
feat: action chunking optionally in inference script
1 parent d2e35b7 commit 82f1ace

1 file changed

Lines changed: 15 additions & 2 deletions

File tree

examples/inference/franka.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ class InferenceConfig:
105105
on_same_machine: bool = False
106106
fps: int = FPS
107107
record_path: str = RECORD_PATH
108+
n_action_steps: int | None = None
108109

109110

110111
def load_inference_config() -> InferenceConfig:
@@ -132,6 +133,7 @@ def __init__(self, env: gym.Env, cfg: InferenceConfig):
132133
self.frame_rate = SimpleFrameRate(self._cfg.fps)
133134
self._listener = keyboard.Listener(on_press=self._on_press)
134135
self._listener.start()
136+
self._action_buffer = []
135137

136138
def _on_press(self, key):
137139
try:
@@ -166,6 +168,18 @@ def obs_rcs2agents(self, obs: dict, info: dict | None = None) -> Obs:
166168

167169
return Obs(cameras=cameras, gripper=None, info=info, state=np.concatenate(state))
168170

171+
def act(self, obs_dict) -> None:
172+
done = False
173+
if self._cfg.n_action_steps is None:
174+
return self.remote_agent.act(obs_dict)
175+
if len(self._action_buffer) == 0:
176+
action = self.remote_agent.act(obs_dict)
177+
selected_action = action.action[:self._cfg.n_action_steps]
178+
self._action_buffer = selected_action.tolist()
179+
done = action.done
180+
act = self._action_buffer.pop(0)
181+
return Act(action=act, done=done)
182+
169183
def action_agents2rcs(self, action: Act) -> dict[str, Any]:
170184
act = {}
171185
for idx, robot in enumerate(self._cfg.robot_keys):
@@ -242,7 +256,7 @@ def loop(self):
242256
sleep(0.05)
243257
continue
244258

245-
action = self.remote_agent.act(copy.deepcopy(obs_dict))
259+
action = self.act(copy.deepcopy(obs_dict))
246260
if action.done:
247261
logger.info("done issued by agent, resetting environment")
248262
obs, _ = self.env.reset()
@@ -370,7 +384,6 @@ def main():
370384
# env = RHCWrapper(env, exec_horizon=1)
371385

372386
controller = ModelInference(env_rel, cfg)
373-
input("robot is about to be controlled by AI, press enter to enable keyboard control")
374387
with env_rel:
375388
controller.loop()
376389

0 commit comments

Comments
 (0)