Update env.py
Browse files
env.py
CHANGED
|
@@ -224,11 +224,18 @@ def make_env(n_envs: int = 1, use_async_envs: bool = False, cfg=None) -> dict[st
|
|
| 224 |
combined_tasks = {**TARGET_TASKS, **PRETRAINING_TASKS}
|
| 225 |
|
| 226 |
# 벤치마크인지 단일 태스크인지 구분
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
|
|
|
|
|
|
| 230 |
else:
|
| 231 |
-
task_names = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
|
| 233 |
|
| 234 |
out = defaultdict(dict)
|
|
|
|
| 224 |
combined_tasks = {**TARGET_TASKS, **PRETRAINING_TASKS}
|
| 225 |
|
| 226 |
# 벤치마크인지 단일 태스크인지 구분
|
| 227 |
+
parts = [t.strip() for t in task_name.split(",")]
|
| 228 |
+
if len(parts) == 1 and parts[0] in combined_tasks:
|
| 229 |
+
task_names = combined_tasks[parts[0]]
|
| 230 |
+
if gym_kwargs.get("split") is None:
|
| 231 |
+
gym_kwargs["split"] = "target" if parts[0] in TARGET_TASKS else "pretrain"
|
| 232 |
else:
|
| 233 |
+
task_names = []
|
| 234 |
+
for part in parts:
|
| 235 |
+
if part in combined_tasks:
|
| 236 |
+
task_names.extend(combined_tasks[part])
|
| 237 |
+
else:
|
| 238 |
+
task_names.append(part)
|
| 239 |
|
| 240 |
|
| 241 |
out = defaultdict(dict)
|