-
Notifications
You must be signed in to change notification settings - Fork 779
Open
Labels
Description
Hello, I’m exploring how to combine LangGraph with process rewards and handle node failures in sequential agents. I have two main questions that aren’t covered in existing tutorials:
Process reward + LangGraph
Is it possible to call emit_reward directly within an agent’s node to emit intermediate rewards during training? For example:
import agentlightning as agl
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.graph.state import CompiledStateGraph
class Agent:
    """Two-step sequential LangGraph agent that emits a reward from each node.

    This is a sketch: ``State``, ``prompt``, ``self.invoke_prompt`` and every
    ``...`` are placeholders that the snippet leaves unspecified.
    """

    def __init__(self, *args, **kwargs):
        # Previously spelled ``___init___`` (three underscores), which Python
        # never invokes as a constructor. The real arguments are elided.
        ...

    def node_1(self, state: State) -> State:
        """Run the first step and emit its intermediate (process) reward."""
        result = self.invoke_prompt(prompt)
        reward_1 = ...  # placeholder: score for this step
        agl.emit_reward(reward_1)
        new_state = ...  # placeholder: state handed to the next node
        return new_state

    def node_2(self, state: State) -> State:
        """Run the second step and emit its intermediate (process) reward."""
        result = self.invoke_prompt(prompt)
        reward_2 = ...  # placeholder: score for this step
        agl.emit_reward(reward_2)
        new_state = ...  # placeholder: final state
        return new_state

    def graph(self) -> CompiledStateGraph[State]:
        """Build and compile the linear graph START -> node1 -> node2 -> END."""
        builder = StateGraph(State)
        # Register the nodes under explicit names. The methods are called
        # node_1/node_2, so ``self.node1`` does not exist, and a name inferred
        # from the function ("node_1") would not match the edges below.
        builder.add_node("node1", self.node_1)
        builder.add_node("node2", self.node_2)
        builder.add_edge(START, "node1")
        builder.add_edge("node1", "node2")
        builder.add_edge("node2", END)
        return builder.compile()
class LitSQLAgent(agl.LitAgent[Dict[str, Any]]):
    """Agent Lightning wrapper that runs the LangGraph ``Agent`` for one rollout.

    This is a sketch: constructor arguments and the ``Agent(...)`` arguments
    are elided in the snippet.
    """

    def __init__(self, *args, **kwargs):
        # ``def __init__(...)`` is not valid syntax; the real arguments are elided.
        ...

    def rollout(
        self,
        task: Dict[str, Any],
        resources: agl.NamedResources,
        rollout: agl.Rollout,
    ) -> None:
        """Invoke the compiled graph on one task with LangChain tracing attached.

        Any exception raised while invoking the graph is logged and swallowed,
        so the rollout ends without a result.
        """
        # The snippet used ``question`` and ``rollout_id`` without defining them.
        # NOTE(review): assumes the task dict carries a "question" key - confirm.
        question = task["question"]
        # NOTE(review): assumes agl.Rollout exposes ``rollout_id`` - confirm.
        rollout_id = rollout.rollout_id

        agent = Agent(...).graph()
        try:
            # Required to make the langchain tracing work
            handler = self.tracer.get_langchain_handler()
            result = agent.invoke(
                {"question": question},
                {"callbacks": [handler] if handler else [], "recursion_limit": 100},
            )
        except Exception as e:
            logger.exception(f"[Rollout {rollout_id}] Error during agent invocation: {e}")
            return
return
Handle node failures
Should I add a conditional edge for each intermediate node? For example:
class Agent:
    """Variant of the sequential agent that can stop right after node1.

    This is a sketch: ``is_training`` and ``is_success`` are not defined in
    the snippet, and the node methods are elided.
    """

    ...  # nearly the same as the agent above

    def check_1(self, state: State):
        """Pick the step after node1: END on a failed training step, else "node2"."""
        if is_training and not is_success(state):
            return END
        return "node2"

    def graph(self) -> CompiledStateGraph[State]:
        """Build and compile the graph, routing out of node1 through ``check_1``."""
        builder = StateGraph(State)
        # Register the nodes under explicit names. The node methods are called
        # node_1/node_2, so ``self.node1`` does not exist, and a name inferred
        # from the function would not match the edges below.
        builder.add_node("node1", self.node_1)
        builder.add_node("node2", self.node_2)
        builder.add_edge(START, "node1")
        builder.add_edge("node2", END)
        # node1 has no fixed successor; check_1 returns the next node name or END.
        builder.add_conditional_edges("node1", self.check_1)
        return builder.compile()