1 change: 1 addition & 0 deletions README.md
@@ -32,6 +32,7 @@ Coming Soon...
- [X] Bookstack - by [@flifloo](https://github.com/flifloo) :pray:
- [X] Mattermost - by [@itaykal](https://github.com/Itaykal) :pray:
- [X] RocketChat - by [@flifloo](https://github.com/flifloo) :pray:
- [X] Stack Overflow for Teams - by [@allen-munsch](https://github.com/allen-munsch) :pray:
- [ ] Gitlab Issues (In PR :pray:)
- [ ] Zendesk (In PR :pray:)
- [ ] Azure DevOps (In PR :pray:)
2 changes: 1 addition & 1 deletion app/api/data_source.py
@@ -85,7 +85,7 @@ async def list_locations(request: Request, data_source_name: str, config: dict)
@router.post("")
async def connect_data_source(request: Request, dto: AddDataSourceDto, background_tasks: BackgroundTasks) -> int:
logger.info(f"Adding data source {dto.name} with config {json.dumps(dto.config)}")
data_source = DataSourceContext.create_data_source(name=dto.name, config=dto.config)
data_source = await DataSourceContext.create_data_source(name=dto.name, config=dto.config)
Posthog.added_data_source(uuid=request.headers.get('uuid'), name=dto.name)
# in main.py we have a background task that runs every 5 minutes and indexes the data source
# but here we want to index the data source immediately
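Note for callers: `create_data_source` is now a coroutine, so it must be awaited from an `async def` context (the FastAPI route above already is one). A minimal, illustrative sketch of the new calling convention (the config values are placeholders):

```python
import asyncio

from data_source.api.context import DataSourceContext


async def main() -> None:
    # Without the await, this would return an un-run coroutine
    # object instead of a BaseDataSource instance.
    data_source = await DataSourceContext.create_data_source(
        name="confluence", config={"url": "https://example.atlassian.net"}
    )
    print(data_source)


asyncio.run(main())
```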
1 change: 1 addition & 0 deletions app/clear_ack_queue.sh
@@ -0,0 +1 @@
sqlite3 ~/.gerev/storage/tasks.sqlite3/data.db 'delete from ack_queue_task where _id in (select _id from ack_queue_task);'
1 change: 1 addition & 0 deletions app/clear_data_sources.sh
@@ -0,0 +1 @@
sqlite3 ~/.gerev/storage/db.sqlite3 'delete from data_source where id in (select id from data_source);'
2 changes: 1 addition & 1 deletion app/data_source/api/base_data_source.py
@@ -61,7 +61,7 @@ def get_config_fields() -> List[ConfigField]:

@staticmethod
@abstractmethod
def validate_config(config: Dict) -> None:
async def validate_config(config: Dict) -> None:
"""
Validates the config and raises an exception if it's invalid.
"""
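Every subclass has to mirror the new `async` signature (the conversions follow below); their bodies may still block, since they mostly wrap synchronous clients. A hypothetical subclass sketch, using `asyncio.to_thread` (an assumption, not something this PR adds) to keep a blocking check off the event loop:

```python
import asyncio
from typing import Dict, List

import requests

from data_source.api.base_data_source import BaseDataSource, ConfigField, HTMLInputType


class ExampleDataSource(BaseDataSource):
    @staticmethod
    def get_config_fields() -> List[ConfigField]:
        return [ConfigField(label="URL", name="url", type=HTMLInputType.TEXT)]

    @staticmethod
    async def validate_config(config: Dict) -> None:
        # requests.get blocks, so run it in a worker thread; the event
        # loop stays free while the credentials are checked.
        response = await asyncio.to_thread(requests.get, config["url"])
        response.raise_for_status()
```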
36 changes: 25 additions & 11 deletions app/data_source/api/context.py
@@ -6,9 +6,10 @@
from data_source.api.base_data_source import BaseDataSource
from data_source.api.dynamic_loader import DynamicLoader, ClassInfo
from data_source.api.exception import KnownException
from db_engine import Session
from db_engine import Session, async_session
from pydantic.error_wrappers import ValidationError
from schemas import DataSourceType, DataSource

from sqlalchemy import select

logger = logging.getLogger(__name__)

@@ -48,22 +49,31 @@ def get_data_source_classes(cls) -> Dict[str, BaseDataSource]:
return cls._data_source_classes

@classmethod
def create_data_source(cls, name: str, config: dict) -> BaseDataSource:
with Session() as session:
data_source_type = session.query(DataSourceType).filter_by(name=name).first()
async def create_data_source(cls, name: str, config: dict) -> BaseDataSource:
async with async_session() as session:
data_source_type = await session.execute(
select(DataSourceType).filter_by(name=name)
)
data_source_type = data_source_type.scalar_one_or_none()
if data_source_type is None:
raise KnownException(message=f"Data source type {name} does not exist")

data_source_class = DynamicLoader.get_data_source_class(name)
logger.info(f"validating config for data source {name}")
data_source_class.validate_config(config)
await data_source_class.validate_config(config)
config_str = json.dumps(config)

data_source_row = DataSource(type_id=data_source_type.id, config=config_str, created_at=datetime.now())
data_source_row = DataSource(
type_id=data_source_type.id,
config=config_str,
created_at=datetime.now(),
)
session.add(data_source_row)
session.commit()
await session.commit()

data_source = data_source_class(config=config, data_source_id=data_source_row.id)
data_source = data_source_class(
config=config, data_source_id=data_source_row.id
)
cls._data_source_instances[data_source_row.id] = data_source

return data_source
@@ -95,8 +105,12 @@ def _load_connected_sources_from_db(cls):
for data_source in data_sources:
data_source_cls = DynamicLoader.get_data_source_class(data_source.type.name)
config = json.loads(data_source.config)
data_source_instance = data_source_cls(config=config, data_source_id=data_source.id,
last_index_time=data_source.last_indexed_at)
try:
data_source_instance = data_source_cls(config=config, data_source_id=data_source.id,
last_index_time=data_source.last_indexed_at)
except ValidationError as e:
logger.error(f"Error loading data source {data_source.id}: {e}")
continue  # skip the invalid source but keep loading the rest
cls._data_source_instances[data_source.id] = data_source_instance

cls._initialized = True
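For readers new to SQLAlchemy's asyncio support, this is the query pattern `create_data_source` now relies on, reduced to a self-contained sketch (the in-memory database and `Thing` model are illustrative; the `aiosqlite` driver must be installed):

```python
import asyncio

from sqlalchemy import Column, Integer, String, select
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class Thing(Base):
    __tablename__ = "thing"
    id = Column(Integer, primary_key=True)
    name = Column(String)


async def demo() -> None:
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    session_factory = sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
    async with session_factory() as session:
        session.add(Thing(name="x"))
        await session.commit()

        # execute() returns a Result; scalar_one_or_none() yields the single
        # matching row or None, the async analogue of Query.first().
        result = await session.execute(select(Thing).filter_by(name="x"))
        print(result.scalar_one_or_none())


asyncio.run(demo())
```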
25 changes: 25 additions & 0 deletions app/data_source/api/utils.py
@@ -4,12 +4,36 @@
from functools import lru_cache
from io import BytesIO
from typing import Optional
import time
import threading
from functools import wraps

import requests


logger = logging.getLogger(__name__)


def rate_limit(*, allowed_per_second: int):
max_period = 1.0 / allowed_per_second
last_call = [time.perf_counter()]
lock = threading.Lock()

def decorate(func):
@wraps(func)
def limit(*args, **kwargs):
with lock:
elapsed = time.perf_counter() - last_call[0]
hold = max_period - elapsed
if hold > 0:
time.sleep(hold)
result = func(*args, **kwargs)
last_call[0] = time.perf_counter()
return result
return limit
return decorate


def snake_case_to_pascal_case(snake_case_string: str):
"""Converts a snake case string to a PascalCase string"""
components = snake_case_string.split('_')
@@ -55,3 +79,4 @@ def get_confluence_user_image(image_url: str, token: str) -> Optional[str]:
return f"data:image/jpeg;base64,{base64.b64encode(image_bytes.getvalue()).decode()}"
except:
logger.warning(f"Failed to get confluence user image {image_url}")
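A short usage sketch of the new `rate_limit` decorator (the target function and URL are illustrative): calls are serialized through the lock, and each caller sleeps just long enough to keep the observed rate at or below `allowed_per_second`.

```python
import requests

from data_source.api.utils import rate_limit


@rate_limit(allowed_per_second=2)
def fetch_status(url: str) -> int:
    return requests.get(url).status_code


# The second and third calls are delayed so that no more than two
# calls start within any one-second window, even across threads.
for _ in range(3):
    print(fetch_status("https://example.com"))
```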

2 changes: 1 addition & 1 deletion app/data_source/sources/bookstack/bookstack.py
@@ -132,7 +132,7 @@ def list_books(book_stack: BookStack) -> List[Dict]:
raise e

@staticmethod
def validate_config(config: Dict) -> None:
async def validate_config(config: Dict) -> None:
try:
parsed_config = BookStackConfig(**config)
book_stack = BookStack(url=parsed_config.url, token_id=parsed_config.token_id,
2 changes: 1 addition & 1 deletion app/data_source/sources/confluence/confluence.py
@@ -61,7 +61,7 @@ def list_all_spaces(confluence: Confluence) -> List[Location]:
return spaces

@staticmethod
def validate_config(config: Dict) -> None:
async def validate_config(config: Dict) -> None:
try:
client = ConfluenceDataSource.confluence_client_from_config(config)
ConfluenceDataSource.list_spaces(confluence=client)
2 changes: 1 addition & 1 deletion app/data_source/sources/confluence/confluence_cloud.py
@@ -25,7 +25,7 @@ def get_config_fields() -> List[ConfigField]:
]

@staticmethod
def validate_config(config: Dict) -> None:
async def validate_config(config: Dict) -> None:
try:
client = ConfluenceCloudDataSource.confluence_client_from_config(config)
ConfluenceCloudDataSource.list_spaces(confluence=client)
2 changes: 1 addition & 1 deletion app/data_source/sources/google_drive/google_drive.py
@@ -43,7 +43,7 @@ def get_config_fields() -> List[ConfigField]:
]

@staticmethod
def validate_config(config: Dict) -> None:
async def validate_config(config: Dict) -> None:
try:
scopes = ['https://www.googleapis.com/auth/drive.readonly']
parsed_config = GoogleDriveConfig(**config)
2 changes: 1 addition & 1 deletion app/data_source/sources/mattermost/mattermost.py
@@ -54,7 +54,7 @@ def get_config_fields() -> List[ConfigField]:
]

@staticmethod
def validate_config(config: Dict) -> None:
async def validate_config(config: Dict) -> None:
try:
parsed_config = MattermostConfig(**config)
maattermost = Driver(options=asdict(parsed_config))
2 changes: 1 addition & 1 deletion app/data_source/sources/rocketchat/rocketchat.py
@@ -54,7 +54,7 @@ def get_display_name(cls) -> str:
return "Rocket.Chat"

@staticmethod
def validate_config(config: Dict) -> None:
async def validate_config(config: Dict) -> None:
rocket_chat_config = RocketchatConfig(**config)
should_verify_ssl = os.environ.get('ROCKETCHAT_VERIFY_SSL') is not None
rocket_chat = RocketChat(user_id=rocket_chat_config.token_id, auth_token=rocket_chat_config.token_secret,
2 changes: 1 addition & 1 deletion app/data_source/sources/slack/slack.py
@@ -42,7 +42,7 @@ def get_config_fields() -> List[ConfigField]:
]

@staticmethod
def validate_config(config: Dict) -> None:
async def validate_config(config: Dict) -> None:
slack_config = SlackConfig(**config)
slack = WebClient(token=slack_config.token)
slack.auth_test()
Empty file.
136 changes: 136 additions & 0 deletions app/data_source/sources/stackoverflow/stackoverflow.py
@@ -0,0 +1,136 @@
import logging
import time
from dataclasses import dataclass
from datetime import datetime
from typing import Dict, List, Optional
import requests

from data_source.api.base_data_source import BaseDataSource, ConfigField, HTMLInputType, BaseDataSourceConfig
from data_source.api.basic_document import DocumentType, BasicDocument
from queues.index_queue import IndexQueue

from data_source.api.utils import rate_limit

logger = logging.getLogger(__name__)


@dataclass
class StackOverflowPost:
link: str
score: int
last_activity_date: int
creation_date: int
post_id: Optional[int] = None
post_type: Optional[str] = None
body_markdown: Optional[str] = None
owner_account_id: Optional[int] = None
owner_reputation: Optional[int] = None
owner_user_id: Optional[int] = None
owner_user_type: Optional[str] = None
owner_profile_image: Optional[str] = None
owner_display_name: Optional[str] = None
owner_link: Optional[str] = None
title: Optional[str] = None
last_edit_date: Optional[int] = None  # Unix epoch, like last_activity_date
tags: Optional[List[str]] = None
view_count: Optional[int] = None
article_id: Optional[int] = None
article_type: Optional[str] = None

class StackOverflowConfig(BaseDataSourceConfig):
api_key: str
team_name: str


@rate_limit(allowed_per_second=15)
def rate_limited_get(url, headers):
'''
https://api.stackoverflowteams.com/docs/throttle
https://api.stackexchange.com/docs/throttle
Every application is subject to an IP based concurrent request throttle.
If a single IP is making more than 30 requests a second, new requests will be dropped.
The exact ban period is subject to change, but will be on the order of 30 seconds to a few minutes typically.
Note that exactly what response an application gets (in terms of HTTP code, text, and so on)
is undefined when subject to this ban; we consider > 30 request/sec per IP to be very abusive and thus cut the requests off very harshly.
'''
resp = requests.get(url, headers=headers)
if resp.status_code == 429:
logger.warning('Rate limited, sleeping for 5 minutes')
time.sleep(300)
return rate_limited_get(url, headers)
return resp


class StackOverflowDataSource(BaseDataSource):

@staticmethod
def get_config_fields() -> List[ConfigField]:
return [
ConfigField(label="PAT API Key", name="api_key", type=HTMLInputType.TEXT),
ConfigField(label="Team Name", name="team_name", type=HTMLInputType.TEXT),
]

@staticmethod
async def validate_config(config: Dict) -> None:
so_config = StackOverflowConfig(**config)
url = f'https://api.stackoverflowteams.com/2.3/questions?team={so_config.team_name}'
response = rate_limited_get(url, headers={'X-API-Access-Token': so_config.api_key})
response.raise_for_status()

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
so_config = StackOverflowConfig(**self._raw_config)
self._api_key = so_config.api_key
self._team_name = so_config.team_name

def _fetch_posts(self, *, api_key: str, team_name: str, page: int, doc_type: str) -> None:
team_fragment = f'&team={team_name}'
# this filter adds "body_markdown" to the response; Stack Exchange filters are unique, static tokens
# it is unclear whether this particular filter is per-account or usable by everyone
filter_fragment = '&filter=!nOedRLbqzB'
page_fragment = f'&page={page}'
# the API takes a Unix epoch timestamp (currently 10 digits); only fetch items newer than the last index time
from_date_fragment = f'&fromdate={int(self._last_index_time.timestamp())}'
url = f'https://api.stackoverflowteams.com/2.3/{doc_type}?{team_fragment}{filter_fragment}{page_fragment}{from_date_fragment}'
response = rate_limited_get(url, headers={'X-API-Access-Token': api_key})
response.raise_for_status()
response = response.json()
has_more = response['has_more']
items = response['items']
logger.info(f'Fetched {len(items)} {doc_type} from Stack Overflow')
for item_dict in items:
owner_fields = {}
if 'owner' in item_dict:
owner_fields = {f"owner_{k}": v for k, v in item_dict.pop('owner').items()}
if 'title' not in item_dict:
item_dict['title'] = item_dict['link']
post = StackOverflowPost(**item_dict, **owner_fields)
last_modified = datetime.fromtimestamp(post.last_edit_date or post.last_activity_date)
if last_modified < self._last_index_time:
return
logger.info(f'Feeding {doc_type} {post.title}')
post_document = BasicDocument(title=post.title, content=post.body_markdown, author=post.owner_display_name,
timestamp=datetime.fromtimestamp(post.creation_date), id=post.post_id,
data_source_id=self._data_source_id, location=post.link,
url=post.link, author_image_url=post.owner_profile_image,
type=DocumentType.MESSAGE)
IndexQueue.get_instance().put_single(doc=post_document)
if has_more:
# paginate onto the queue
self.add_task_to_queue(self._fetch_posts, api_key=self._api_key, team_name=self._team_name, page=page + 1, doc_type=doc_type)

def _feed_new_documents(self) -> None:
self.add_task_to_queue(self._fetch_posts, api_key=self._api_key, team_name=self._team_name, page=1, doc_type='posts')
# TODO: figure out how to get articles
# self.add_task_to_queue(self._fetch_posts, api_key=self._api_key, team_name=self._team_name, page=1, doc_type='articles')


# def test():
# import os
# config = {"api_key": os.environ['SO_API_KEY'], "team_name": os.environ['SO_TEAM_NAME']}
# so = StackOverflowDataSource(config=config, data_source_id=1)
# so._feed_new_documents()
#
#
# if __name__ == '__main__':
# test()
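The `owner_*` flattening in `_fetch_posts` is easiest to see on a sample item; a minimal illustration with made-up field values (the import path assumes this module's location in the repo):

```python
from datetime import datetime

from data_source.sources.stackoverflow.stackoverflow import StackOverflowPost

# A pared-down item in the shape returned by /2.3/posts (values are illustrative).
item_dict = {
    "link": "https://example.stackoverflowteams.com/questions/1",
    "score": 3,
    "last_activity_date": 1680000000,
    "creation_date": 1670000000,
    "post_id": 1,
    "post_type": "question",
    "body_markdown": "How do I ...?",
    "owner": {"display_name": "jane", "user_id": 7},
}

# The same flattening _fetch_posts performs: nested owner fields become owner_*.
owner_fields = {f"owner_{k}": v for k, v in item_dict.pop("owner").items()}
if "title" not in item_dict:
    item_dict["title"] = item_dict["link"]

post = StackOverflowPost(**item_dict, **owner_fields)
print(post.title, post.owner_display_name, datetime.fromtimestamp(post.creation_date))
```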
10 changes: 9 additions & 1 deletion app/db_engine.py
@@ -5,11 +5,19 @@

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
# import base document and then register all classes
from schemas.base import Base

from paths import SQLITE_DB_PATH

engine = create_engine(f'sqlite:///{SQLITE_DB_PATH}')
db_url = f'sqlite:///{SQLITE_DB_PATH}'
print('DB engine path:', db_url)
engine = create_engine(db_url)
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)

async_db_url = db_url.replace('sqlite', 'sqlite+aiosqlite', 1)
print('ASYNC DB engine path:', async_db_url)
async_engine = create_async_engine(async_db_url)
async_session = sessionmaker(async_engine, expire_on_commit=False, class_=AsyncSession)
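A minimal sketch of consuming the new `async_session` factory (assumes the `aiosqlite` driver is installed, since the rewritten async URL points at it):

```python
import asyncio

from sqlalchemy import select

from db_engine import async_session
from schemas import DataSource


async def list_data_sources() -> None:
    async with async_session() as session:
        result = await session.execute(select(DataSource))
        for data_source in result.scalars():
            print(data_source.id, data_source.created_at)


asyncio.run(list_data_sources())
```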