Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions resources_servers/textworld/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# TextWorld Resources Server

Integrates: https://github.com/microsoft/TextWorld

Native multi-turn text adventure environments for RL training.

## Quick Start

### 1. Generate Dataset

```bash
python3 resources_servers/textworld/scripts/generate_games.py --workers 32
# Output: resources_servers/textworld/games/ with train/val/test splits
```

### 2. Create Training Examples

```bash
python3 resources_servers/textworld/scripts/create_examples.py \
--all --split train \
--output resources_servers/textworld/data/train.jsonl

python3 resources_servers/textworld/scripts/create_examples.py \
--all --split val \
--output resources_servers/textworld/data/val.jsonl

python3 resources_servers/textworld/scripts/create_examples.py \
--all --split test \
--output resources_servers/textworld/data/test.jsonl
```

### 3. Start Servers

```bash
vllm serve Qwen/Qwen3-30B-A3B \
--dtype auto \
--tensor-parallel-size 8 \
--gpu-memory-utilization 0.9 \
--enable-auto-tool-choice --tool-call-parser hermes \
--host 0.0.0.0 \
--port 10240

ng_run "+config_paths=[resources_servers/textworld/configs/textworld.yaml,responses_api_models/vllm_model/configs/vllm_model.yaml]"
```

### 4. Collect Rollouts

```bash
ng_collect_rollouts +agent_name=textworld_simple_agent \
+input_jsonl_fpath=resources_servers/textworld/data/train.jsonl \
+output_jsonl_fpath=resources_servers/textworld/data/rollouts_train.jsonl \
+limit=5
```

### 5. View Results

```bash
ng_viewer +jsonl_fpath=resources_servers/textworld/data/rollouts_example.jsonl
```

## Testing

```bash
ng_test +entrypoint=resources_servers/textworld
```

## Validation

```bash
python3 resources_servers/textworld/scripts/validate_dataset.py resources_servers/textworld/games
```

169 changes: 169 additions & 0 deletions resources_servers/textworld/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict

from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel, ConfigDict, Field

import textworld
from textworld import EnvInfos
from textworld.core import Environment

from nemo_gym.base_resources_server import (
BaseResourcesServerConfig,
BaseSeedSessionRequest,
BaseSeedSessionResponse,
BaseVerifyRequest,
BaseVerifyResponse,
SimpleResourcesServer,
)
from nemo_gym.server_utils import SESSION_ID_KEY


class TextworldResourcesServerConfig(BaseResourcesServerConfig):
expose_admissible_commands: bool = False


class TextworldSeedSessionRequest(BaseSeedSessionRequest):
game_file: str


class TextworldSeedSessionResponse(BaseSeedSessionResponse):
initial_observation: str
objective: str
admissible_commands: list[str] | None = None


class ExecuteCommandRequest(BaseModel):
command: str


class ExecuteCommandResponse(BaseModel):
observation: str
score: int
done: bool
won: bool
admissible_commands: list[str] | None = None


class TextworldVerifyRequest(BaseVerifyRequest):
pass


class TextworldResourcesServer(SimpleResourcesServer):
config: TextworldResourcesServerConfig
session_id_to_env: Dict[str, Environment] = Field(default_factory=dict)

model_config = ConfigDict(arbitrary_types_allowed=True)

def setup_webserver(self) -> FastAPI:
app = super().setup_webserver()

app.post("/execute_command")(self.execute_command)

return app

async def seed_session(
self, request: Request, body: TextworldSeedSessionRequest
) -> TextworldSeedSessionResponse:
session_id = request.session[SESSION_ID_KEY]

request_infos = EnvInfos(
feedback=True,
won=True,
lost=True,
score=True,
max_score=True,
objective=True,
admissible_commands=self.config.expose_admissible_commands,
)

from pathlib import Path

games_dir = Path(__file__).parent / "games"

# Try direct path first
game_path = games_dir / body.game_file

# Hacky: if game not found, search in train/val/test subdirectories TODO: Fix paths
if not game_path.exists():
for split in ["train", "val", "test"]:
for game_type in ["coin_collector", "treasure_hunter", "simple", "cooking", "custom"]:
potential_path = games_dir / split / game_type / body.game_file
if potential_path.exists():
game_path = potential_path
break
if game_path.exists():
break

if not game_path.exists():
raise FileNotFoundError(f"Game file not found: {body.game_file} (searched in {games_dir})")

env = textworld.start(str(game_path), request_infos=request_infos)

state = env.reset()

self.session_id_to_env[session_id] = env

response = TextworldSeedSessionResponse(
initial_observation=state.feedback if hasattr(state, "feedback") else str(state),
objective=state["objective"] if "objective" in state else "",
)

if self.config.expose_admissible_commands and "admissible_commands" in state:
response.admissible_commands = state["admissible_commands"]

return response

async def execute_command(
self, request: Request, body: ExecuteCommandRequest
) -> ExecuteCommandResponse:
session_id = request.session[SESSION_ID_KEY]

if session_id not in self.session_id_to_env:
raise HTTPException(
status_code=400,
detail="Session not initialized. Please call seed_session first.",
)

env = self.session_id_to_env[session_id]

state, score, done = env.step(body.command)

response = ExecuteCommandResponse(
observation=state.feedback if hasattr(state, "feedback") else str(state),
score=score,
done=done,
won=state.get("won", False),
)

if self.config.expose_admissible_commands and "admissible_commands" in state:
response.admissible_commands = state["admissible_commands"]

return response

async def verify(self, request: Request, body: TextworldVerifyRequest) -> BaseVerifyResponse:
session_id = request.session[SESSION_ID_KEY]

reward = 0.0
if session_id in self.session_id_to_env:
env = self.session_id_to_env[session_id]
if hasattr(env, "state") and env.state.get("won", False):
reward = 1.0

return BaseVerifyResponse(**body.model_dump(), reward=reward)


if __name__ == "__main__":
TextworldResourcesServer.run_webserver()
21 changes: 21 additions & 0 deletions resources_servers/textworld/configs/textworld.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
textworld_resources_server:
resources_servers:
textworld:
entrypoint: app.py
expose_admissible_commands: false
textworld_simple_agent:
responses_api_agents:
simple_agent:
entrypoint: app.py
max_steps: 15
resources_server:
type: resources_servers
name: textworld_resources_server
model_server:
type: responses_api_models
name: policy_model
datasets:
- name: example
type: example
jsonl_fpath: resources_servers/textworld/data/example.jsonl
num_repeats: 1
5 changes: 5 additions & 0 deletions resources_servers/textworld/data/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
*train.jsonl
*validation.jsonl
*train_prepare.jsonl
*validation_prepare.jsonl
*example_prepare.jsonl
Loading
Loading