Commit 964e859

Merge pull request #60 from GerevAI/task-queue
Task queue
2 parents 711610d + 928335c

41 files changed: +813 -449
app/alembic/versions/792a820e9374_document_id_in_data_source.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+"""document id_in_data_source
+
+Revision ID: 792a820e9374
+Revises: 9c2f5b290b16
+Create Date: 2023-03-26 11:27:05.341609
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '792a820e9374'
+down_revision = '9c2f5b290b16'
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.add_column('document', sa.Column('id_in_data_source', sa.String(length=64), default='__none__'))
+
+
+def downgrade() -> None:
+    op.drop_column('document', 'id_in_data_source')
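
Worth noting for this migration: `default='__none__'` is a client-side SQLAlchemy default, so it only fills the column for rows inserted through SQLAlchemy after the upgrade; pre-existing `document` rows keep NULL unless backfilled (a `server_default` would typically fill them at DDL time). A minimal self-contained sketch of that behavior, using a throwaway in-memory table rather than the app's real model:

    from sqlalchemy import Column, Integer, String, create_engine
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Doc(Base):  # throwaway stand-in, not the app's real 'document' model
        __tablename__ = 'document'
        id = Column(Integer, primary_key=True)
        id_in_data_source = Column(String(64), default='__none__')

    engine = create_engine('sqlite://')  # in-memory scratch database
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add(Doc())  # the client-side default is applied on INSERT
        session.commit()
        print(session.query(Doc.id_in_data_source).scalar())  # '__none__'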

app/alembic/versions/9c2f5b290b16_add_fields_to_datasourcetype_model.py

Lines changed: 2 additions & 2 deletions
@@ -10,7 +10,7 @@
 from alembic import op
 import sqlalchemy as sa
 
-from data_source_api.utils import get_class_by_data_source_name
+from data_source.api.dynamic_loader import DynamicLoader
 from db_engine import Session
 from schemas import DataSourceType
 
@@ -29,7 +29,7 @@ def upgrade() -> None:
     # update existing data sources
     data_source_types = session.query(DataSourceType).all()
     for data_source_type in data_source_types:
-        data_source_class = get_class_by_data_source_name(data_source_type.name)
+        data_source_class = DynamicLoader.get_data_source_class(data_source_type.name)
         config_fields = data_source_class.get_config_fields()
 
         data_source_type.config_fields = json.dumps([config_field.dict() for config_field in config_fields])
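
`DynamicLoader.get_data_source_class` takes over from the removed `data_source_api.utils.get_class_by_data_source_name`. A rough sketch of the kind of convention-based lookup such a loader can perform; the module path and class-naming rule below are assumptions, and the real loader in this commit discovers classes through `find_data_sources` and `get_class`:

    import importlib

    def get_data_source_class(name: str):
        # Hypothetical convention: 'confluence' resolves to
        # data_source.sources.confluence.ConfluenceDataSource
        module = importlib.import_module(f"data_source.sources.{name}")
        return getattr(module, f"{name.capitalize()}DataSource")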

app/api/data_source.py

Lines changed: 29 additions & 37 deletions
@@ -1,20 +1,18 @@
 import base64
 import json
-from datetime import datetime
 from typing import List
 
-from fastapi import APIRouter, BackgroundTasks
+from fastapi import APIRouter
 from pydantic import BaseModel
-from starlette.responses import Response
+from starlette.background import BackgroundTasks
 
-from data_source_api.base_data_source import ConfigField
-from data_source_api.exception import KnownException
-from data_source_api.utils import get_class_by_data_source_name
+from data_source.api.base_data_source import ConfigField
+from data_source.api.context import DataSourceContext
 from db_engine import Session
 from schemas import DataSourceType, DataSource
 
 router = APIRouter(
-    prefix='/data-source',
+    prefix='/data-sources',
 )
 
 
@@ -39,50 +37,44 @@ def from_data_source_type(data_source_type: DataSourceType) -> 'DataSourceTypeDto':
         )
 
 
-@router.get("/list-types")
+class ConnectedDataSourceDto(BaseModel):
+    id: int
+    name: str
+
+
+@router.get("/types")
 async def list_data_source_types() -> List[DataSourceTypeDto]:
     with Session() as session:
         data_source_types = session.query(DataSourceType).all()
         return [DataSourceTypeDto.from_data_source_type(data_source_type)
                 for data_source_type in data_source_types]
 
 
-@router.get("/list-connected")
-async def list_connected_data_sources() -> List[str]:
+@router.get("/connected")
+async def list_connected_data_sources() -> List[ConnectedDataSourceDto]:
     with Session() as session:
         data_sources = session.query(DataSource).all()
-        return [data_source.type.name for data_source in data_sources]
+        return [ConnectedDataSourceDto(id=data_source.id, name=data_source.type.name)
+                for data_source in data_sources]
 
 
 class AddDataSource(BaseModel):
     name: str
     config: dict
 
 
-@router.post("/add")
+@router.delete("/{data_source_id}")
+async def delete_data_source(data_source_id: int):
+    DataSourceContext.delete_data_source(data_source_id=data_source_id)
+    return {"success": "Data source deleted successfully"}
+
+
+@router.post("")
 async def add_integration(dto: AddDataSource, background_tasks: BackgroundTasks):
-    with Session() as session:
-        data_source_type = session.query(DataSourceType).filter_by(name=dto.name).first()
-        if data_source_type is None:
-            return {"error": "Data source type does not exist"}
-
-        data_source_class = get_class_by_data_source_name(dto.name)
-        try:
-            data_source_class.validate_config(dto.config)
-        except KnownException as e:
-            return Response(e.message, status_code=501)
-
-        config_str = json.dumps(dto.config)
-        ds = DataSource(type_id=data_source_type.id, config=config_str, created_at=datetime.now())
-        session.add(ds)
-        session.commit()
-
-        data_source_id = session.query(DataSource).filter_by(type_id=data_source_type.id)\
-            .order_by(DataSource.id.desc()).first().id
-        data_source = data_source_class(config=dto.config, data_source_id=data_source_id)
-
-        # in main.py we have a background task that runs every 5 minutes and indexes the data source
-        # but here we want to index the data source immediately
-        background_tasks.add_task(data_source.index)
-
-    return {"success": "Data source added successfully"}
+    data_source = DataSourceContext.create_data_source(name=dto.name, config=dto.config)
+
+    # in main.py we have a background task that runs every 5 minutes and indexes the data source
+    # but here we want to index the data source immediately
+    background_tasks.add_task(data_source.index)
+
+    return {"success": "Data source added successfully"}
File renamed without changes.
File renamed without changes.

app/data_source_api/base_data_source.py renamed to app/data_source/api/base_data_source.py

Lines changed: 26 additions & 4 deletions
@@ -2,12 +2,13 @@
 from abc import abstractmethod, ABC
 from datetime import datetime
 from enum import Enum
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Callable
 import re
 
 from pydantic import BaseModel
 
 from db_engine import Session
+from queues.task_queue import TaskQueue, Task
 from schemas import DataSource
 
 
@@ -80,16 +81,37 @@ def __init__(self, config: Dict, data_source_id: int, last_index_time: datetime
         if last_index_time is None:
             last_index_time = datetime(2012, 1, 1)
         self._last_index_time = last_index_time
+        self._last_task_time = None
 
-    def _set_last_index_time(self) -> None:
+    def _save_index_time_in_db(self) -> None:
+        """
+        Sets the index time in the database, to be now
+        """
        with Session() as session:
             data_source: DataSource = session.query(DataSource).filter_by(id=self._data_source_id).first()
             data_source.last_indexed_at = datetime.now()
             session.commit()
 
-    def index(self) -> None:
+    def add_task_to_queue(self, function: Callable, **kwargs):
+        task = Task(data_source_id=self._data_source_id,
+                    function_name=function.__name__,
+                    kwargs=kwargs)
+        TaskQueue.get_instance().add_task(task)
+
+    def run_task(self, function_name: str, **kwargs) -> None:
+        self._last_task_time = datetime.now()
+        function = getattr(self, function_name)
+        function(**kwargs)
+
+    def index(self, force: bool = False) -> None:
+        if self._last_task_time is not None and not force:
+            # Don't index if the last task was less than an hour ago
+            time_since_last_task = datetime.now() - self._last_task_time
+            if time_since_last_task.total_seconds() < 60 * 60:
+                logging.info("Skipping indexing data source because it was indexed recently")
+                return
         try:
-            self._set_last_index_time()
+            self._save_index_time_in_db()
             self._feed_new_documents()
         except Exception as e:
             logging.exception("Error while indexing data source")
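
The new methods split indexing into an enqueue side and an execute side: `add_task_to_queue` serializes a bound method by name into a `Task`, and `run_task` later resolves that name with `getattr` and replays it, stamping `_last_task_time` so `index` throttles itself to once an hour unless `force=True`. A minimal sketch of the consumer half of that round trip; the dequeue call is an assumption, since the `TaskQueue` internals are not part of this excerpt:

    from data_source.api.context import DataSourceContext
    from queues.task_queue import TaskQueue

    def consume_one():
        # Hypothetical dequeue call; the TaskQueue consumer API is not
        # shown in this commit's excerpt.
        task = TaskQueue.get_instance().get_task()
        data_source = DataSourceContext.get_data_source(task.data_source_id)
        data_source.run_task(task.function_name, **task.kwargs)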

app/data_source_api/basic_document.py renamed to app/data_source/api/basic_document.py

Lines changed: 6 additions & 1 deletion
@@ -1,6 +1,7 @@
 from datetime import datetime
 from dataclasses import dataclass
 from enum import Enum
+from typing import Union
 
 
 class DocumentType(Enum):
@@ -32,7 +33,7 @@ def from_mime_type(cls, mime_type: str):
 
 @dataclass
 class BasicDocument:
-    id: int
+    id: Union[int, str]
     data_source_id: int
     type: DocumentType
     title: str
@@ -44,3 +45,7 @@ class BasicDocument:
     url: str
     file_type: FileType = None
 
+    @property
+    def id_in_data_source(self):
+        return str(self.data_source_id) + '_' + str(self.id)
+
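
The `id_in_data_source` property gives each document an identifier that is unique across data sources by prefixing the (now possibly string-valued) per-source id with the data source's database id. A self-contained sketch of the semantics, mirroring the property on a reduced stand-in rather than importing `BasicDocument`, whose remaining required fields are not all shown in this hunk:

    from dataclasses import dataclass
    from typing import Union

    @dataclass
    class Doc:  # reduced stand-in for BasicDocument
        id: Union[int, str]
        data_source_id: int

        @property
        def id_in_data_source(self) -> str:
            return str(self.data_source_id) + '_' + str(self.id)

    print(Doc(id="C042-17", data_source_id=3).id_in_data_source)  # 3_C042-17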

app/data_source/api/context.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
+import json
+from datetime import datetime
+from typing import Dict, List
+
+from data_source.api.base_data_source import BaseDataSource
+from data_source.api.dynamic_loader import DynamicLoader, ClassInfo
+from data_source.api.exception import KnownException
+from db_engine import Session
+from schemas import DataSourceType, DataSource
+
+
+class DataSourceContext:
+    """
+    This class is responsible for loading data sources and caching them.
+    It dynamically loads data source types from the data_source/sources directory.
+    It loads data sources from the database and caches them.
+    """
+    _initialized = False
+    _data_sources: Dict[int, BaseDataSource] = {}
+
+    @classmethod
+    def get_data_source(cls, data_source_id: int) -> BaseDataSource:
+        if not cls._initialized:
+            cls.init()
+            cls._initialized = True
+
+        return cls._data_sources[data_source_id]
+
+    @classmethod
+    def create_data_source(cls, name: str, config: dict) -> BaseDataSource:
+        with Session() as session:
+            data_source_type = session.query(DataSourceType).filter_by(name=name).first()
+            if data_source_type is None:
+                raise KnownException(message=f"Data source type {name} does not exist")
+
+            data_source_class = DynamicLoader.get_data_source_class(name)
+            data_source_class.validate_config(config)
+            config_str = json.dumps(config)
+
+            data_source_row = DataSource(type_id=data_source_type.id, config=config_str, created_at=datetime.now())
+            session.add(data_source_row)
+            session.commit()
+
+            data_source = data_source_class(config=config, data_source_id=data_source_row.id)
+            cls._data_sources[data_source_row.id] = data_source
+
+            return data_source
+
+    @classmethod
+    def delete_data_source(cls, data_source_id: int):
+        with Session() as session:
+            data_source = session.query(DataSource).filter_by(id=data_source_id).first()
+            if data_source is None:
+                raise KnownException(message=f"Data source {data_source_id} does not exist")
+
+            session.delete(data_source)
+            session.commit()
+
+            del cls._data_sources[data_source_id]
+
+    @classmethod
+    def init(cls):
+        cls._add_data_sources_to_db()
+        cls._load_context_from_db()
+
+    @classmethod
+    def _load_context_from_db(cls):
+        with Session() as session:
+            data_sources: List[DataSource] = session.query(DataSource).all()
+            for data_source in data_sources:
+                data_source_cls = DynamicLoader.get_data_source_class(data_source.type.name)
+                config = json.loads(data_source.config)
+                data_source_instance = data_source_cls(config=config, data_source_id=data_source.id,
+                                                       last_index_time=data_source.last_indexed_at)
+                cls._data_sources[data_source.id] = data_source_instance
+
+        cls._initialized = True
+
+    @classmethod
+    def _add_data_sources_to_db(cls):
+        data_sources: Dict[str, ClassInfo] = DynamicLoader.find_data_sources()
+
+        with Session() as session:
+            for source_name in data_sources.keys():
+                if session.query(DataSourceType).filter_by(name=source_name).first():
+                    continue
+
+                class_info = data_sources[source_name]
+                data_source_class = DynamicLoader.get_class(file_path=class_info.file_path,
+                                                            class_name=class_info.name)
+
+                config_fields = data_source_class.get_config_fields()
+                config_fields_str = json.dumps([config_field.dict() for config_field in config_fields])
+                new_data_source = DataSourceType(name=source_name,
+                                                 display_name=data_source_class.get_display_name(),
+                                                 config_fields=config_fields_str)
+                session.add(new_data_source)
+            session.commit()
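
A hedged usage sketch of the new context; the source name and config values are invented, and `init()` is what `main.py` would presumably call at startup:

    from data_source.api.context import DataSourceContext

    DataSourceContext.init()  # discover source types on disk, load saved sources

    data_source = DataSourceContext.create_data_source(
        name="confluence",  # must match a discovered data source type
        config={"url": "https://example.atlassian.net", "token": "..."},  # invented
    )
    data_source.index(force=True)  # bypass the one-hour throttle

    DataSourceContext.delete_data_source(data_source_id=1)  # id is illustrative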
