diff --git a/pocketflow/__init__.py b/pocketflow/__init__.py index 32ed3dcd..0b7368fc 100644 --- a/pocketflow/__init__.py +++ b/pocketflow/__init__.py @@ -1,7 +1,29 @@ -import asyncio, warnings, copy, time +import asyncio, warnings, copy, time, sys class BaseNode: - def __init__(self): self.params,self.successors={},{} + def __init__(self): + self.params, self.successors = {}, {} + self.name = self.get_instance_name() or f"node_{hash(self)}" + self.flow = None # Will be set by Flow._propagate_flow + self.parent = None # Will be set by Flow._propagate_flow + + def get_instance_name(self): + """Find the variable name this instance is assigned to, if any""" + try: + frame = sys._getframe(1) + while frame: + for scope in (frame.f_locals, frame.f_globals): + for key, value in scope.items(): + if value is self and not key.startswith('_') and key != 'self': + return key + frame = frame.f_back + except (AttributeError, ValueError): + pass + return None + + def _get_name(self): + """Return the instance name, either from name attribute or lookup""" + return self.name or self.get_instance_name() or f"node_{hash(self)}" def set_params(self,params): self.params=params def add_successor(self,node,action="default"): if action in self.successors: warnings.warn(f"Overwriting successor for action '{action}'") @@ -11,8 +33,8 @@ def exec(self,prep_res): pass def post(self,shared,prep_res,exec_res): pass def _exec(self,prep_res): return self.exec(prep_res) def _run(self,shared): p=self.prep(shared);e=self._exec(p);return self.post(shared,p,e) - def run(self,shared): - if self.successors: warnings.warn("Node won't run successors. Use Flow.") + def run(self,shared): + if self.successors: warnings.warn("Node won't run successors. Use Flow.") return self._run(shared) def __rshift__(self,other): return self.add_successor(other) def __sub__(self,action): @@ -37,22 +59,45 @@ class BatchNode(Node): def _exec(self,items): return [super(BatchNode,self)._exec(i) for i in items] class Flow(BaseNode): - def __init__(self,start): super().__init__();self.start=start + def __init__(self, start, name=None): + super().__init__() + self.start = start + self.name = name or self.get_instance_name() or f"flow_{hash(self)}" + self._propagate_flow(self.start) + + def _propagate_flow(self, node, visited=None): + """Set flow and parent references on all nodes in the flow""" + if visited is None: + visited = set() + + if node is None or id(node) in visited: + return + + visited.add(id(node)) + node.flow = self + node.parent = self.parent if hasattr(self, 'parent') else None + + for successor in node.successors.values(): + self._propagate_flow(successor, visited) def get_next_node(self,curr,action): nxt=curr.successors.get(action or "default") if not nxt and curr.successors: warnings.warn(f"Flow ends: '{action}' not found in {list(curr.successors)}") return nxt - def _orch(self,shared,params=None): - curr,p=copy.copy(self.start),(params or {**self.params}) - while curr: curr.set_params(p);c=curr._run(shared);curr=copy.copy(self.get_next_node(curr,c)) + def _orch(self, shared, params=None): + curr, p = copy.copy(self.start), (params or {**self.params}) + while curr: + curr.set_params(p) + c = curr._run(shared) + curr = copy.copy(self.get_next_node(curr, c)) def _run(self,shared): pr=self.prep(shared);self._orch(shared);return self.post(shared,pr,None) def exec(self,prep_res): raise RuntimeError("Flow can't exec.") class BatchFlow(Flow): - def _run(self,shared): - pr=self.prep(shared) or [] - for bp in pr: self._orch(shared,{**self.params,**bp}) - return self.post(shared,pr,None) + def _run(self, shared): + pr = self.prep(shared) or [] + for bp in pr: + self._orch(shared, {**self.params, **bp}) + return self.post(shared, pr, None) class AsyncNode(Node): def prep(self,shared): raise RuntimeError("Use prep_async.") @@ -64,14 +109,14 @@ async def prep_async(self,shared): pass async def exec_async(self,prep_res): pass async def exec_fallback_async(self,prep_res,exc): raise exc async def post_async(self,shared,prep_res,exec_res): pass - async def _exec(self,prep_res): + async def _exec(self,prep_res): for i in range(self.max_retries): try: return await self.exec_async(prep_res) except Exception as e: if i==self.max_retries-1: return await self.exec_fallback_async(prep_res,e) if self.wait>0: await asyncio.sleep(self.wait) - async def run_async(self,shared): - if self.successors: warnings.warn("Node won't run successors. Use AsyncFlow.") + async def run_async(self,shared): + if self.successors: warnings.warn("Node won't run successors. Use AsyncFlow.") return await self._run_async(shared) async def _run_async(self,shared): p=await self.prep_async(shared);e=await self._exec(p);return await self.post_async(shared,p,e) @@ -82,19 +127,26 @@ class AsyncParallelBatchNode(AsyncNode,BatchNode): async def _exec(self,items): return await asyncio.gather(*(super(AsyncParallelBatchNode,self)._exec(i) for i in items)) class AsyncFlow(Flow,AsyncNode): - async def _orch_async(self,shared,params=None): - curr,p=copy.copy(self.start),(params or {**self.params}) - while curr:curr.set_params(p);c=await curr._run_async(shared) if isinstance(curr,AsyncNode) else curr._run(shared);curr=copy.copy(self.get_next_node(curr,c)) + async def _orch_async(self, shared, params=None): + curr, p = copy.copy(self.start), (params or {**self.params}) + while curr: + curr.set_params(p) + if isinstance(curr, AsyncNode): + c = await curr._run_async(shared) + else: + c = curr._run(shared) + curr = copy.copy(self.get_next_node(curr, c)) async def _run_async(self,shared): p=await self.prep_async(shared);await self._orch_async(shared);return await self.post_async(shared,p,None) class AsyncBatchFlow(AsyncFlow,BatchFlow): - async def _run_async(self,shared): - pr=await self.prep_async(shared) or [] - for bp in pr: await self._orch_async(shared,{**self.params,**bp}) - return await self.post_async(shared,pr,None) + async def _run_async(self, shared): + pr = await self.prep_async(shared) or [] + for bp in pr: + await self._orch_async(shared, {**self.params, **bp}) + return await self.post_async(shared, pr, None) class AsyncParallelBatchFlow(AsyncFlow,BatchFlow): - async def _run_async(self,shared): - pr=await self.prep_async(shared) or [] - await asyncio.gather(*(self._orch_async(shared,{**self.params,**bp}) for bp in pr)) - return await self.post_async(shared,pr,None) \ No newline at end of file + async def _run_async(self, shared): + pr = await self.prep_async(shared) or [] + await asyncio.gather(*(self._orch_async(shared, {**self.params, **bp}) for bp in pr)) + return await self.post_async(shared, pr, None) diff --git a/pocketflow/example.txt b/pocketflow/example.txt new file mode 100644 index 00000000..57907086 --- /dev/null +++ b/pocketflow/example.txt @@ -0,0 +1 @@ +fasdasdasd diff --git a/pocketflow/rework.py b/pocketflow/rework.py new file mode 100644 index 00000000..0a335c34 --- /dev/null +++ b/pocketflow/rework.py @@ -0,0 +1,237 @@ + +# from pocketflow import * +from __init__ import * +import os + +def call_llm(prompt): + # Your API logic here + return prompt + +class LoadFile(Node): + def __init__(self, name=None): + super().__init__() + # Use provided name or fall back to automatic lookup + self.name = name or self.name + def prep(self, shared): + print(f" In : {self.__class__.__name__}") + """Load file from disk""" + filename = self.params["filename"] + with open(filename, "r") as file: + return file.read() + + def exec(self, prep_res): + """Return file content""" + return prep_res + + def post(self, shared, prep_res, exec_res): + """Store file content in shared""" + shared["file_content"] = exec_res + return "default" + + +class GetOpinion(Node): + def __init__(self, name=None): + super().__init__() + # Use provided name or fall back to automatic lookup + self.name = name or self.name + + def prep(self, shared): + print(f" In : {self.__class__.__name__}") + print(f"My name is: {self.name} (instance of {self.__class__.__name__})") + if self.flow: + print(f"Flow name: {self.flow.name}") + if self.flow.parent: + print(f"Parent flow: {self.flow.parent.name}") + """Get file content from shared""" + if not shared.get("reworked_file_content"): + return shared["file_content"] + else: + return "Original text :\n" + shared["file_content"] + "Revised version:\n" + shared["reworked_file_content"] + + def exec(self, prep_res): + """Ask LLM for opinion on file content""" + prompt = f"What's your opinion on this text: {prep_res}. Provide opinion on how to make it better." + return call_llm(prompt) + + def post(self, shared, prep_res, exec_res): + """Store opinion in shared""" + shared["opinion"] = exec_res + return "default" + +class GetValidation(Node): + def __init__(self, name=None): + super().__init__() + # Use provided name or fall back to automatic lookup + self.name = name or self.name + def prep(self, shared): + print(f" In : {self.__class__.__name__}") + """Get file content from shared""" + shared['discussion'] = shared["file_content"] + shared["opinion"] + "Final revised text : " + shared["reworked_file_content"] + return + + def exec(self, prep_res): + """Ask LLM for opinion on file content""" + prompt = f"Validate that the final revised text is valid and reflects the changes proposed in opinion : {prep_res}. \nReply `IS VALID` if it is of `NOT VALID` if it needs some more work." + return call_llm(prompt) + + def post(self, shared, prep_res, exec_res): + """Store rework count in shared""" + if "IS VALID" in exec_res: + return "default" + else: + return "invalid" + + +class ReworkFile(Node): + def __init__(self, name=None): + super().__init__() + # Use provided name or fall back to automatic lookup + self.name = name or self.name + def prep(self, shared): + print(f" In : {self.__class__.__name__}") + """Get file content and opinion from shared""" + return shared["file_content"], shared["opinion"] + + def exec(self, prep_res): + """Ask LLM to rework file based on opinion""" + file_content, opinion = prep_res + prompt = f"Rework this text based on the opinion: {opinion}\n\nOriginal text: {file_content}" + return call_llm(prompt) + + def post(self, shared, prep_res, exec_res): + """Store reworked file content in shared""" + if "rework2_flow_min_count" in self.params: + rework_count = self.params["rework2_flow_min_count"] + shared["reworked_file_content"] = exec_res + if not shared.get("reworked_file_content_count"): + shared["reworked_file_content_count"] = 1 + elif shared.get("reworked_file_content_count"): + shared["reworked_file_content_count"] += 1 + + if shared["reworked_file_content_count"] < rework_count: + print(f"Less than {self.params["rework2_flow_min_count"]} rework for rework2_flow, so going for pass #{shared["reworked_file_content_count"]}.") + return "rework" + else: + return "default" + else: + shared["reworked_file_content"] = exec_res + + +class SaveFile(Node): + def __init__(self, name=None): + super().__init__() + # Use provided name or fall back to automatic lookup + self.name = name or self.name + def prep(self, shared): + print(f" In : {self.__class__.__name__}") + """Get reworked file content and original filename from shared""" + filename = self.params["filename"] + if "reworked_file_content" in shared: + return shared["reworked_file_content"], filename + else: + print("Error") + + def exec(self, prep_res): + """Save reworked file content to new file""" + reworked_file_content, filename = prep_res + new_filename = f"{filename.split('.')[0]}_v2.{filename.split('.')[-1]}" + with open(new_filename, "w") as file: + file.write(reworked_file_content) + return reworked_file_content + + def post(self, shared, prep_res, exec_res): + filename = self.params["filename"] + """Return success message""" + print(f"Saved to {filename} the content : \n{exec_res}") + + +# # # Comment this from here +# # First flow +# Create nodes +load_Node = LoadFile(name="load_Node") +opinion_Node = GetOpinion(name="opinion_Node") +rework_Node = ReworkFile(name="rework_Node") +save_Node = SaveFile(name="save_Node") + +# Connect nodes +load_Node >> opinion_Node >> rework_Node >> save_Node + +# Create flow +rework_Flow = Flow(start=load_Node,name="rework_Flow") + +# Set flow params +rework_Flow.set_params({"filename": "example.txt"}) +# Run flow +shared = {} +rework_Flow.run(shared) +# # # To here for second workflow to work + +# # Second flow +# Create nodes +load2_Node = LoadFile(name="load2_Node") +opinion2_Node = GetOpinion(name="opinion2_Node") +rework2_Node = ReworkFile(name="rework2_Node") +valid2_Node = GetValidation(name="valid2_Node") +save2_Node = SaveFile(name="save2_Node") + +print(f" NAME is : {opinion2_Node.name}") + +# Connect nodes +load2_Node >> opinion2_Node +opinion2_Node >> rework2_Node + +rework2_Node - "default" >> valid2_Node +rework2_Node - "rework" >> opinion2_Node + +# Get second opinion it if rework asked because in rework_flow2 and less than 2 rework +valid2_Node - "invalid" >> opinion2_Node +valid2_Node - "default" >> save2_Node + +# Create flow with explicit name +rework2_Flow = Flow(start=load2_Node, name="rework2_Flow") +# rework2_Flow.name = "rework2_Flow" # Set explicit name + +# Set flow params +# This will not set params if class Flow was already initialized with other params ? +rework2_Flow.set_params({"filename": "example.txt", "rework2_flow_min_count" : 3}) + +# Run flow +shared2 = {} +rework2_Flow.run(shared2) + +def build_mermaid(start): + visited, lines = set(), ["graph LR"] + + def get_name(n): + """Get the node's name for use in the diagram""" + if isinstance(n, Flow): + return n._get_name() + return n._get_name().replace(' ', '_') # Mermaid needs no spaces in node names + + def link(a, b): + lines.append(f" {get_name(a)} --> {get_name(b)}") + def walk(node, parent=None): + if node in visited: + return parent and link(parent, node) + visited.add(node) + if isinstance(node, Flow): + node.start and parent and link(parent, node) + # Add flow name and class name to subgraph label + flow_label = f"{node._get_name()} ({type(node).__name__})" + lines.append(f"\n subgraph {get_name(node)}[\"{flow_label}\"]") + node.start and walk(node.start) + for nxt in node.successors.values(): + node.start and walk(nxt, node.start) or (parent and link(parent, nxt)) or walk(nxt) + lines.append(" end\n") + else: + # Add both instance name and class name to node label + node_label = f"{node._get_name()} ({type(node).__name__})" + lines.append(f" {get_name(node)}[\"{node_label}\"]") + parent and link(parent, node) + [walk(nxt, node) for nxt in node.successors.values()] + walk(start) + return "\n".join(lines) + +print(build_mermaid(start=rework_Flow)) + +print(build_mermaid(start=rework2_Flow))