@@ -87,7 +87,7 @@ def path_to_uri(path, scheme="https://", domain="docs.ray.io"):
     return scheme + domain + path.split(domain)[-1]
 
 
-def parse_file(record):
+def parse_html_file(record):
     html_content = load_html_file(record["path"])
     if not html_content:
         return []
@@ -100,6 +100,17 @@ def parse_file(record):
     ]
 
 
+def parse_text_file(record):
+    with open(record["path"]) as f:
+        text = f.read()
+    return [
+        {
+            "source": str(record["path"]),
+            "text": text,
+        }
+    ]
+
+
 class EmbedChunks:
     def __init__(self, model_name):
         self.embedding_model = HuggingFaceEmbeddings(
@@ -139,6 +150,7 @@ def __call__(self, batch):
 @app.command()
 def create_index(
     docs_path: Annotated[str, typer.Option(help="location of data")] = DOCS_PATH,
+    extension_type: Annotated[str, typer.Option(help="type of data")] = "html",
     embedding_model: Annotated[str, typer.Option(help="embedder")] = EMBEDDING_MODEL,
     chunk_size: Annotated[int, typer.Option(help="chunk size")] = CHUNK_SIZE,
     chunk_overlap: Annotated[int, typer.Option(help="chunk overlap")] = CHUNK_OVERLAP,
@@ -148,11 +160,17 @@ def create_index(
 
     # Dataset
     ds = ray.data.from_items(
-        [{"path": path} for path in Path(docs_path).rglob("*.html") if not path.is_dir()]
+        [
+            {"path": path}
+            for path in Path(docs_path).rglob(f"*.{extension_type}")
+            if not path.is_dir()
+        ]
     )
 
     # Sections
-    sections_ds = ds.flat_map(parse_file)
+    parser = parse_html_file if extension_type == "html" else parse_text_file
+    sections_ds = ds.flat_map(parser)
+    # TODO: do we really need to take_all()? Bring the splitter to the cluster
     sections = sections_ds.take_all()
 
     # Chunking
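The TODO in the last hunk hints at a follow-up: instead of materializing every section on the driver with take_all(), the chunking step itself could run on the cluster as another flat_map over sections_ds. A minimal sketch of that idea, assuming LangChain's RecursiveCharacterTextSplitter and the {"source", "text"} section schema produced by the parsers above; chunk_section and its hard-coded sizes are hypothetical and not part of this change:

from langchain.text_splitter import RecursiveCharacterTextSplitter

def chunk_section(section):
    # Hypothetical helper: split one parsed section into overlapping chunks.
    # The sizes below are placeholders for the chunk_size/chunk_overlap CLI options.
    splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", " ", ""],
        chunk_size=300,
        chunk_overlap=50,
    )
    chunks = splitter.create_documents(
        texts=[section["text"]],
        metadatas=[{"source": section["source"]}],
    )
    return [{"text": c.page_content, "source": c.metadata["source"]} for c in chunks]

# chunks_ds = sections_ds.flat_map(chunk_section)  # chunking runs on the cluster, no take_all()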