Skip to content

Commit f9ca6ea

Browse files
authored
Merge pull request #50 from victordibia/dev
Update explanation and expansion visualizations (brushed bar charts, arc line charts)
2 parents 873c2cb + 2f6c1e1 commit f9ca6ea

File tree

21 files changed

+2692
-15904
lines changed

21 files changed

+2692
-15904
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ NeuralQA is comprised of several high level modules:
3838

3939
- **Reader**: For each retrieved passage, a BERT based model predicts a span that contains the answer to the question. In practice, retrieved passages may be lengthy and BERT based models can process a maximum of 512 tokens at a time. NeuralQA handles this in two ways. Lengthy passages are chunked into smaller sections with a configurable stride. Secondly, NeuralQA offers the option of extracting a subset of relevant snippets (RelSnip) which a BERT reader can then scan to find answers. Relevant snippets are portions of the retrieved document that contain exact match results for the search query.
4040

41-
- **Expander**: Methods for generating additional (relevant) query terms to improve recall. Currently, we implement Contextual Query Expansion using finetuned Masked Language Models.
41+
- **Expander**: Methods for generating additional (relevant) query terms to improve recall. Currently, we implement Contextual Query Expansion using finetuned Masked Language Models. This is implemented via a user in the loop flow where the user can choose to include any suggested expansion terms.
42+
43+
<img width="100%" src="https://raw.githubusercontent.com/victordibia/neuralqa/master/docs/images/expand.jpg">
4244

4345
- **User Interface**: NeuralQA provides a visual user interface for performing queries (manual queries where question and context are provided as well as queries over a search index), viewing results and also sensemaking of results (reranking of passages based on answer scores, highlighting keyword match, model explanations).
4446

docs/images/expand.jpg

55.1 KB
Loading

neuralqa/server/routehandlers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ async def get_answers(params: Answer):
3939
self.reader_pool.selected_model = params.reader
4040
self.retriever_pool.selected_retriever = params.retriever
4141

42+
# print(params.query + " ".join(params.expansionterms))
4243
# answer question based on provided context
4344
if (params.retriever == "none" or self.retriever_pool.selected_retriever == None):
4445
answers = self.reader_pool.model.answer_question(
@@ -47,10 +48,13 @@ async def get_answers(params: Answer):
4748
answer["index"] = 0
4849
answer_holder.append(answer)
4950
# answer question based on retrieved passages from elastic search
50-
else:
5151

52+
else:
53+
# add query expansion terms to query if any
54+
retriever_query = params.query + \
55+
" ".join(params.expansionterms)
5256
num_fragments = 5
53-
query_results = self.retriever_pool.retriever.run_query(params.retriever, params.query,
57+
query_results = self.retriever_pool.retriever.run_query(params.retriever, retriever_query,
5458
max_documents=params.max_documents, fragment_size=params.fragment_size,
5559
relsnip=params.relsnip, num_fragments=num_fragments, highlight_tags=False)
5660
# print(query_results)

neuralqa/server/routemodels.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class Answer(BaseModel):
2323
reader: str = None
2424
relsnip: bool = True
2525
expander: Optional[str] = None
26+
expansionterms: Optional[list] = None
2627
retriever: Optional[str] = "manual"
2728

2829

neuralqa/server/serve.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545

4646

4747
@api.get('/config')
48-
async def ui_config():
48+
async def get_config():
4949
config = app_config.config["ui"]
5050
# show only listed models to ui
5151
config["queryview"]["options"]["relsnip"] = app_config.config["relsnip"]

neuralqa/server/ui/=7.0.0

Lines changed: 0 additions & 8 deletions
This file was deleted.

neuralqa/server/ui/package-lock.json

Lines changed: 0 additions & 15605 deletions
This file was deleted.

neuralqa/server/ui/package.json

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,7 @@
1616
"react": "^16.13.0",
1717
"react-dom": "^16.13.0",
1818
"react-router-dom": "^5.1.2",
19-
"react-scripts": "^3.4.3",
20-
"react-vega": "^7.4.1",
21-
"vega": "^5.14.0",
22-
"vega-lib": "^4.4.0",
23-
"vega-lite": "^4.14.1"
19+
"react-scripts": "^3.4.3"
2420
},
2521
"scripts": {
2622
"start": "react-scripts start",

neuralqa/server/ui/src/components/Main.js

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@ import React, { Component } from "react";
1111
import { getJSONData, sampleConfig } from "./helperfunctions/HelperFunctions";
1212
import { Route, HashRouter } from "react-router-dom";
1313

14-
import ExplainView from "./explainview/ExplainView";
1514
import QueryView from "./queryview/QueryView";
1615
import Header from "./header/Header";
1716
import Footer from "./footer/Footer";
1817
import { createBrowserHistory } from "history";
18+
import TestView from "./testview/TestView";
19+
// import TestView from "./testview/TestView";
1920

2021
const history = createBrowserHistory({
2122
basename: "", // The base URL of the app (see below)
@@ -85,7 +86,7 @@ class Main extends Component {
8586
<Header data={this.state.config.header}></Header>
8687
<main className="container-fluid p10">
8788
<Route exact path="/" component={mQueryView} />
88-
<Route exact path="/ex" component={ExplainView} />
89+
<Route exact path="/ex" component={TestView} />
8990
</main>
9091
</div>
9192
)}
Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
import React, { Component } from "react";
2+
import * as d3 from "d3";
3+
import "./barviz.css";
4+
5+
class BarViz extends Component {
6+
constructor(props) {
7+
super(props);
8+
9+
this.grads = props.data.gradients;
10+
11+
// this.minChartWidth = 900;
12+
this.minChartHeight = 250;
13+
this.minChartWidth = this.props.minChartWidth || 800;
14+
15+
this.brushHeight = 60;
16+
this.barColor = "#0062ff";
17+
this.inactiveColor = "rgba(85, 85, 85, 0.586)";
18+
this.initialBrushPercentage = 35 / this.grads.length;
19+
20+
// window.addEventListener("resize", handleResize);
21+
}
22+
23+
getLabel(d, i) {
24+
return i + "*.*" + d.token + " *.* (" + d.gradient.toFixed(2) + ")";
25+
}
26+
27+
componentWillUnmount() {}
28+
29+
componentDidUpdate(prevProps, prevState) {}
30+
31+
setupScalesAxes(data) {
32+
let self = this;
33+
this.chartMargin = { top: 5, right: 0, bottom: 0, left: 0 };
34+
this.chartWidth =
35+
this.minChartWidth - this.chartMargin.left - this.chartMargin.right;
36+
this.chartHeight =
37+
this.minChartHeight - this.chartMargin.top - this.chartMargin.bottom;
38+
this.xScale = d3
39+
.scaleBand()
40+
.domain(data.map((d, i) => self.getLabel(d, i)))
41+
.range([this.chartMargin.left, this.chartWidth - this.chartMargin.right]);
42+
43+
this.yScale = d3
44+
.scaleLinear()
45+
.domain([0, d3.max(data, (d) => d.gradient)])
46+
.nice()
47+
.range([this.chartHeight, 0]);
48+
}
49+
50+
createSVGBox = (selector, height) => {
51+
return d3
52+
.select(selector)
53+
.append("svg")
54+
.attr(
55+
"width",
56+
this.chartWidth + this.chartMargin.left + this.chartMargin.right
57+
)
58+
.attr("height", height + this.chartMargin.top + this.chartMargin.bottom)
59+
.append("g")
60+
.attr(
61+
"transform",
62+
"translate(" + this.chartMargin.left + "," + this.chartMargin.top + ")"
63+
);
64+
};
65+
createBarRects = (svg, x, y, data, chartclass, transparency) => {
66+
svg
67+
.append("g")
68+
.attr("class", chartclass)
69+
.selectAll("rect")
70+
.data(data)
71+
.join("rect")
72+
.attr("x", (d, i) => x(this.getLabel(d, i)))
73+
.attr("y", (d) => y(d.gradient))
74+
.attr("height", (d) => y(0) - y(d.gradient))
75+
.attr("width", x.bandwidth())
76+
.attr("class", transparency ? "strokedbarrect" : "")
77+
.attr(
78+
"fill",
79+
(d) => "rgba(0, 98, 255, " + (transparency ? d.gradient : 1) + ")"
80+
);
81+
};
82+
83+
drawBrushGraph(data) {
84+
let self = this;
85+
this.brushXScale = this.xScale.copy();
86+
this.brushYScale = this.yScale.copy().range([this.brushHeight, 0]);
87+
const x = this.brushXScale;
88+
const y = this.brushYScale;
89+
const mainXZoom = d3
90+
.scaleLinear()
91+
.range([this.chartMargin.left, this.chartWidth - this.chartMargin.right])
92+
.domain([
93+
this.chartMargin.left,
94+
this.chartWidth - this.chartMargin.right,
95+
]);
96+
97+
const svg = this.createSVGBox("div.d3brush", this.brushHeight);
98+
99+
this.createBarRects(svg, x, y, data, "minibars", false);
100+
const brush = d3
101+
.brushX()
102+
.extent([
103+
[this.chartMargin.left, 0.5],
104+
[this.chartWidth - this.chartMargin.right, this.brushHeight],
105+
])
106+
.on("brush", brushed)
107+
.on("start", brushStarted)
108+
.on("end", brushEnded);
109+
110+
const defaultSelection = [
111+
x.range()[0],
112+
(x.range()[1] - x.range()[0]) * self.initialBrushPercentage,
113+
];
114+
115+
svg.append("g").call(brush).call(brush.move, defaultSelection);
116+
117+
function brushStarted() {
118+
// console.log("brush started");
119+
d3.select("div.barviz")
120+
.selectAll("text.textlabel")
121+
.attr("class", "textinvisible textlabel");
122+
}
123+
function brushEnded() {
124+
d3.select("div.barviz")
125+
.selectAll("text.textlabel")
126+
.attr("class", "textlabel");
127+
const extentX = d3.event.selection;
128+
// console.log("brush ended", extentX);
129+
if (extentX) {
130+
// const selected = x
131+
// .domain()
132+
// .filter(
133+
// (d) =>
134+
// extentX[0] - x.bandwidth() + 1e-2 <= x(d) &&
135+
// x(d) <= extentX[1] - 1e-2
136+
// );
137+
138+
updateScalePostBrush(extentX);
139+
140+
const svg = d3.select("div.barviz");
141+
svg
142+
.selectAll("text.textlabel")
143+
.data(data)
144+
.attr("x", (d, i) => {
145+
return (
146+
self.xScale(self.getLabel(d, i)) + self.xScale.bandwidth() / 2
147+
);
148+
})
149+
.attr("y", (d) => {
150+
return self.yScale.range()[0];
151+
});
152+
}
153+
}
154+
155+
function brushed() {
156+
const extentX = d3.event.selection;
157+
const selected = x
158+
.domain()
159+
.filter(
160+
(d) =>
161+
extentX[0] - x.bandwidth() + 1e-2 <= x(d) &&
162+
x(d) <= extentX[1] - 1e-2
163+
);
164+
165+
d3.select("div.d3brush")
166+
.select(".minibars")
167+
.selectAll("rect")
168+
.style("fill", (d, i) => {
169+
return selected.indexOf(self.getLabel(d, i)) > -1
170+
? self.barColor
171+
: self.inactiveColor;
172+
});
173+
174+
updateScalePostBrush(extentX);
175+
update(self.grads);
176+
}
177+
178+
function updateScalePostBrush(extentX) {
179+
let originalRange = mainXZoom.range();
180+
mainXZoom.domain(extentX);
181+
182+
self.xScale.domain(data.map((d, i) => self.getLabel(d, i)));
183+
self.xScale
184+
.range([mainXZoom(originalRange[0]), mainXZoom(originalRange[1])])
185+
.paddingInner(0.1);
186+
}
187+
188+
function update(data) {
189+
const x = self.xScale;
190+
const y = self.yScale;
191+
const svg = d3.select("div.barviz");
192+
svg
193+
.selectAll("rect.mainbars")
194+
.data(data)
195+
.join("rect")
196+
.attr("x", (d, i) => x(self.getLabel(d, i)))
197+
.attr("y", (d) => y(d.gradient))
198+
.attr("height", (d) => y(0) - y(d.gradient))
199+
.attr("width", x.bandwidth());
200+
}
201+
}
202+
203+
createToolTip(svg) {
204+
// create tooltip
205+
let tooltip = svg
206+
.append("g")
207+
.attr("class", "tooltiptext")
208+
.style("display", "none");
209+
210+
tooltip.append("rect").attr("class", "tooltiprect");
211+
212+
tooltip.append("text").attr("x", 10).attr("dy", "1.2em");
213+
// .style("text-anchor", "middle");
214+
215+
return tooltip;
216+
}
217+
218+
drawGraph(data) {
219+
let self = this;
220+
this.setupScalesAxes(data);
221+
const x = this.xScale;
222+
const y = this.yScale;
223+
224+
const svg = this.createSVGBox("div.barviz", this.chartHeight);
225+
const bar = svg.selectAll("g").data(data).join("g");
226+
227+
bar
228+
.append("rect")
229+
.attr("class", "strokedbarrect mainbars")
230+
.attr(
231+
"fill",
232+
(d) =>
233+
"rgba(0, 98, 255, " +
234+
(d.gradient > 0.5 ? 1 : 0.5 + 0.5 * d.gradient) +
235+
")"
236+
)
237+
.attr("width", x.bandwidth())
238+
.attr("height", (d) => y(0) - y(d.gradient))
239+
.on("mouseover", function () {
240+
tooltip.style("display", null);
241+
d3.select(this).attr("fill", "lightgrey");
242+
})
243+
.on("mouseout", function (d) {
244+
tooltip.style("display", "none");
245+
d3.select(this).attr(
246+
"fill",
247+
"rgba(0, 98, 255, " +
248+
(d.gradient > 0.5 ? 1 : 0.5 + 0.5 * d.gradient) +
249+
")"
250+
);
251+
})
252+
.on("mousemove", function (d) {
253+
var xPosition = d3.mouse(this)[0] + 10;
254+
var yPosition = d3.mouse(this)[1] - 20;
255+
tooltip.attr(
256+
"transform",
257+
"translate(" + xPosition + "," + yPosition + ")"
258+
);
259+
tooltip.select("text").text(d.token);
260+
tooltip
261+
.select("rect")
262+
.attr(
263+
"width",
264+
tooltip.select("text").node().getComputedTextLength() + 20
265+
);
266+
});
267+
268+
bar
269+
.append("text")
270+
// .attr("fill", "white")
271+
.attr("x", (d, i) => {
272+
return x(self.getLabel(d, i));
273+
})
274+
.attr("y", (d) => y(d.gradient))
275+
.attr("class", "textlabel")
276+
.text((d) => d.token);
277+
278+
let tooltip = this.createToolTip(svg);
279+
}
280+
281+
componentDidMount() {
282+
let barvizElement = document.getElementById("barviz");
283+
barvizElement.style.width = this.minChartWidth + "px";
284+
285+
this.drawGraph(this.grads);
286+
this.drawBrushGraph(this.grads);
287+
}
288+
289+
render() {
290+
return (
291+
<div className=" barvizcontent ">
292+
<div className=" ">
293+
<div id="barviz" className="barviz "></div>
294+
<div id="d3brush" className="d3brush "></div>
295+
</div>
296+
</div>
297+
);
298+
}
299+
}
300+
301+
export default BarViz;

0 commit comments

Comments
 (0)