Whalswp commited on
Commit
eb2032b
·
verified ·
1 Parent(s): 6bf0bea

Update env.py

Browse files
Files changed (1) hide show
  1. env.py +11 -4
env.py CHANGED
@@ -224,11 +224,18 @@ def make_env(n_envs: int = 1, use_async_envs: bool = False, cfg=None) -> dict[st
224
  combined_tasks = {**TARGET_TASKS, **PRETRAINING_TASKS}
225
 
226
  # 벤치마크인지 단일 태스크인지 구분
227
- if task_name in combined_tasks:
228
- task_names = combined_tasks[task_name]
229
- gym_kwargs["split"] = "target" if task_name in TARGET_TASKS else "pretrain"
 
 
230
  else:
231
- task_names = [t.strip() for t in task_name.split(",")]
 
 
 
 
 
232
 
233
 
234
  out = defaultdict(dict)
 
224
  combined_tasks = {**TARGET_TASKS, **PRETRAINING_TASKS}
225
 
226
  # 벤치마크인지 단일 태스크인지 구분
227
+ parts = [t.strip() for t in task_name.split(",")]
228
+ if len(parts) == 1 and parts[0] in combined_tasks:
229
+ task_names = combined_tasks[parts[0]]
230
+ if gym_kwargs.get("split") is None:
231
+ gym_kwargs["split"] = "target" if parts[0] in TARGET_TASKS else "pretrain"
232
  else:
233
+ task_names = []
234
+ for part in parts:
235
+ if part in combined_tasks:
236
+ task_names.extend(combined_tasks[part])
237
+ else:
238
+ task_names.append(part)
239
 
240
 
241
  out = defaultdict(dict)