|
| 1 | +import asyncio |
1 | 2 | from contextlib import asynccontextmanager
|
2 | 3 | from dataclasses import dataclass
|
3 | 4 | from datetime import timedelta
|
4 |
| -from typing import AsyncIterator |
| 5 | +from functools import partial |
| 6 | +from pathlib import Path |
| 7 | +from typing import AsyncIterator, Callable, ParamSpec, TypeVar |
5 | 8 |
|
6 | 9 | from cloudkv import AsyncCloudKV
|
7 | 10 |
|
@@ -30,3 +33,43 @@ async def lock(self, agent_name: str) -> AsyncIterator[bool]:
|
30 | 33 | await self.cloud_kv.delete(key)
|
31 | 34 | else:
|
32 | 35 | yield False
|
| 36 | + |
| 37 | + |
| 38 | +@dataclass |
| 39 | +class LocalStorage(SelfImprovingAgentStorage): |
| 40 | + directory: Path = Path('.self-improving-agent') |
| 41 | + |
| 42 | + def __post_init__(self): |
| 43 | + self.directory.mkdir(exist_ok=True) |
| 44 | + |
| 45 | + async def get_patch(self, agent_name: str) -> ModelContextPatch | None: |
| 46 | + file = self.directory / f'{agent_name}.json' |
| 47 | + if file.exists(): |
| 48 | + content = await asyncify(file.read_bytes) |
| 49 | + return ModelContextPatch.model_validate_json(content) |
| 50 | + |
| 51 | + async def set_patch(self, agent_name: str, patch: ModelContextPatch, expires: timedelta) -> None: |
| 52 | + # note we're ignoring expiry here |
| 53 | + file = self.directory / f'{agent_name}.json' |
| 54 | + content = patch.model_dump_json(indent=2) |
| 55 | + await asyncify(file.write_text, content) |
| 56 | + |
| 57 | + @asynccontextmanager |
| 58 | + async def lock(self, agent_name: str) -> AsyncIterator[bool]: |
| 59 | + file = self.directory / f'lock:{agent_name}' |
| 60 | + if not await asyncify(file.exists): |
| 61 | + await asyncify(file.touch) |
| 62 | + try: |
| 63 | + yield True |
| 64 | + finally: |
| 65 | + await asyncify(file.unlink) |
| 66 | + else: |
| 67 | + yield False |
| 68 | + |
| 69 | + |
| 70 | +P = ParamSpec('P') |
| 71 | +R = TypeVar('R') |
| 72 | + |
| 73 | + |
| 74 | +async def asyncify(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R: |
| 75 | + return await asyncio.get_event_loop().run_in_executor(None, partial(func, *args, **kwargs)) |
0 commit comments