From d66ba4a59af7e869d341498ba8db1285df4fb0ac Mon Sep 17 00:00:00 2001 From: cheng-wang2002 Date: Sun, 28 Jul 2024 04:58:28 +0800 Subject: [PATCH] Enable rag inference and evaluation --- contrib/rag/README.md | 44 ++++++++ contrib/rag/build_corpus_index.py | 87 ++++++++++++++++ contrib/rag/corpus.dpr_index | Bin 0 -> 41533 bytes contrib/rag/corpus.txt | 1 + contrib/rag/rag_evaluation.py | 130 ++++++++++++++++++++++++ contrib/rag/rag_inference.py | 160 ++++++++++++++++++++++++++++++ contrib/rag/run_rag_evaluation.sh | 58 +++++++++++ contrib/rag/run_rag_inference.sh | 56 +++++++++++ 8 files changed, 536 insertions(+) create mode 100644 contrib/rag/README.md create mode 100644 contrib/rag/build_corpus_index.py create mode 100644 contrib/rag/corpus.dpr_index create mode 100644 contrib/rag/corpus.txt create mode 100644 contrib/rag/rag_evaluation.py create mode 100644 contrib/rag/rag_inference.py create mode 100644 contrib/rag/run_rag_evaluation.sh create mode 100644 contrib/rag/run_rag_inference.sh diff --git a/contrib/rag/README.md b/contrib/rag/README.md new file mode 100644 index 000000000..29866cf44 --- /dev/null +++ b/contrib/rag/README.md @@ -0,0 +1,44 @@ +# Retrieval-augmented generation + +LMFlow now supports retrieval-augmented generation. We offer four different retrieval methods which include DPR embeddings and BM25. Also, any model supported by LMFlow can be used for generation. + +* DPR(Dense Passage Retrieval) Embeddings: \ +https://arxiv.org/pdf/2004.04906 +* BM25 retriever: \ +https://python.langchain.com/v0.2/docs/integrations/retrievers/bm25/ + +## Requirements +Faiss library is required for dataset indexing. +``` +pip install faiss-cpu pickle rank_bm25 +``` + +## Build indexing for custom corpus for retrieval +If you want to use your own corpus for retrieval, first use `build_corpus_index.py` to build an index of the corpus embeddings. We offer one type of embedding method `dpr`and one retrieval method, `bm25`, which also requires indexing. + +Below is an example that utilizes OpenAI embedding to index a corpus using '\n\n' as the splitter. + +``` +python ./scripts/build_corpus_index --corpus_path='corpus.txt' --splitter='\n\n' --embedding_type='dpr' --data_index_path='corpus' +``` +Then it would save corpus and corpus index to ```corpus.dpr_index```. + +## Inference and Evaluation + +After building indexing of corpus, you can run the script `run_rag_inference.sh` that user can directly input question, and the script `run_rag_evaluation.sh` that user can input the path of dataset. + +Here are two examples of each script. + +``` +bash ./scripts/run_rag_inference.sh --retriever_type='dpr' --corpus_index_path='corpus.dpr_index' --top_k_retrieve=5 +``` + +``` +bash ./scripts/run_rag_evaluation.sh --retriever_type='dpr' --corpus_index_path='corpus.dpr_index' --top_k_retrieve=5 +``` + +## Known issue + +Current `build_corpus_index.py` has memory issue, since it would load all corpus into memory at once, so if the size of corpus is larger than your memory, the process would be broken. Our next step is to enable our program to load corpus piece by piece, so that memory would not be an issue. Also, + + diff --git a/contrib/rag/build_corpus_index.py b/contrib/rag/build_corpus_index.py new file mode 100644 index 000000000..027a662ce --- /dev/null +++ b/contrib/rag/build_corpus_index.py @@ -0,0 +1,87 @@ +import pickle +import os +from transformers import AutoTokenizer, AutoModel +from transformers import HfArgumentParser +from dataclasses import dataclass, field +from typing import Optional +import torch +import faiss +import numpy as np +@dataclass +class RetrieverArguments: + corpus_path: str = field( + metadata={ + "help": "Please specify the path to the document corpus." + } + ) + + embedding_type: Optional[str] = field( + default="dpr", + metadata={ + "help": "Please specify the type of retriever: bm25, or dpr" + } + ) + splitter: Optional[str] = field( + default="\n\n", + metadata={ + "help": "Please specify the splitter of your document." + } + ) + + data_index_path: Optional[str] = field( + default = './data/corpus', + metadata={ + "help": "Please specify the name of data index name." + } + ) + + device: int = field( + default=0, + metadata={ + "help": "The machine rank of gpu is used." + } + ) + + + +parser = HfArgumentParser((RetrieverArguments)) +retriever_args = parser.parse_args_into_dataclasses()[0] +with open(retriever_args.corpus_path) as f: + text = f.read() +texts = text.split(retriever_args.splitter) + +if retriever_args.embedding_type == 'dpr': + model_name = 'sentence-transformers/facebook-dpr-question_encoder-single-nq-base' + device = torch.device(f'cuda:{retriever_args.device}') + + tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/facebook-dpr-question_encoder-single-nq-base') + model = AutoModel.from_pretrained('sentence-transformers/facebook-dpr-question_encoder-single-nq-base').to(device) + encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt').to(device) + + with torch.no_grad(): + model_output = model(**encoded_input) + def cls_pooling(model_output): + return model_output[0][:,0] + embeddings = cls_pooling(model_output) + + + dim = 768 + index = faiss.IndexFlatL2(dim) + index.add(embeddings.cpu().numpy()) + chunks = faiss.serialize_index(index) + with open(retriever_args.data_index_path+'.dpr_index', "wb") as fp: + pickle.dump(texts, fp) + pickle.dump(chunks, fp) + +elif retriever_args.embedding_type == 'bm25': + with open(retriever_args.data_index_path+'.bm25_index', "wb") as fp: + pickle.dump(texts, fp) +else: + raise ValueError('The embedded method is not implemented. \ + Please specify the type of document embedding as one of the choices, [dpr, bm25].') + + + + + + diff --git a/contrib/rag/corpus.dpr_index b/contrib/rag/corpus.dpr_index new file mode 100644 index 0000000000000000000000000000000000000000..8f075784644cd216d0154e5ad89eb6b6170170ad GIT binary patch literal 41533 zcmZs?c{r6{_dkx1Iie)wG^u20B)RuGnu$~@Qc>n1Q-+EL^GxQU(j-uM;Yo=`AMo&&xJrW}?5T zH^-ySFRH$=S=^+lk}i6>+b)LRQtdHA1`}>lK+3k z%v`c=#Vn53zlaL-<~TTWoZMXmHUxSGj3^;cJtJ`qXQL0t-pR?&+23EFmb*Y}fB!&d zZO$I&5uv_50za)=vSDOmj=(EA4)*>6Z~T9NU$^2v$Y&(RjEv?42Dp2=2MZt=aU>_e z&t3qr6UWD!13%|of$skV zhkzvi5zBE46kyEpw0H3F%SiO$mD;z0h|7znt*yE zo)A#g!PEIaV0n59eD@zB|3i}z?%VtNdJ2MUWRw7Ed%;jAL9!5}hkvq#fB-uM@NN3X zs1e5ta)p2b|ElA>+uk!!FoTb`z`KFoPJ*~}^!Z0~cY%KXf$jnJ|EHh7K=MBn7Ti3N zNB#-ojKm7g#*qLJXzJkZ{g0#(-;H?Yf6>?euag`{C_Ivi{>OGN`<((ljeu@1NIU|C zMojrPE&Vf{XCIW2D6?IVivg4aJ6GogSChq$FW0#!`+uAC<{N=QidU#$dJ`TxB8_ph+< z$j2G~{!h-;iAtpG$AYN{52h&q5xxY*>oei*!vs>7Cl9qR{4lywgwOA3#s_t#aARUR z9$T}J&U!2k68Dd?m$ukIE^jPG)o^guIt}i#xw<$}k%QvG>2R)gFX|hb;DjGNR82mf zRlU@;_{<%7vh*FLeMAnn#cv}KT`cV>$)J(@ObC-WK&>@L@f&A3VCd&<^o?N}Irlz` zR*kU)*G=b8r;yJ^G6s*s>TpJDwBuJ$9fP;R(kX2mCOkjBnLqNwox}*QF%* zzt38s(ohyL|6T#vyIj3@q&oBxfot*fUdQVBwSp80Qdz`<6V0&C-v+ zR91+#bU%ak-S@$7R6M-0se|Ur1Z-9qq28gZ;IwrPF6sJA`aasgnA2IXZT%#k_Y*D3 zJ6%jGqFm9?WDShl>J3&4l33B8cskEE6Su9MOPfMec?ZTy;S|nzQW~L)_G@;LYY9HM z$VneP@5e%;Q!}0U`J$25wr6lQ8flNL6~qiq#R;Ex;4?EX_$56N4vZaz$3OW%j%pm? zpZU!ick_tBRjkxLycEx-YSW`CAK`ZRO*AMuN%)xzd;9!6sBYf~k;Yr0B;*2{(YYNq z@2DliQsWuh6{T1{W*ThRXhM&feSp?ynaqwFb=spk#G7f8g6qQbFivqk{w$G%qnC1V zMQ;av z(YXWkbRxh@#|PM+c`)nNMsjwE5}e;wiTBn>BQq6=+`S=q6J-fL$&+E9_gSKplaF_< zxwF~Ek3juU2FNSzpuQ$Y@O0K6Xj^;|g{!ksJn}NSb@=n^SXBGDg|y3@fwlS{cur$4;)JqA?Dc*%eBZ!>{zE?W!HN=c{Du!Cm43h< z6HY-afJa`f}EzqB(-5c}!6#IezU=C59lwpZ^$US~C1Ic+83=!N6N zsa1N2tHI8oglak81-na9u_p`)2i#vj=gDX~ zcq|7e9{x_tOBU19!eJo(ln|TXucZ6eR~)zXB=)7sbB(`uGB(+|MBk#6{ZZsX{Nysp zN~g~1+Rc-xO2;NJ^)O=-A8ZGwjyaIM`6oG^8i56iO<}*14Ix8!cuIj8;5y|8oo^uo zUJeoXc&0R0!=)PUxtPITAuUo>{*Zckg@fq(Nc=NDvx+fZ2IDgc8SJtolA6cxp4LQ6 z4Nj#e4z7UtYP&IIoFRHvoJL>M15l`~j+yxzKq2ukd^;KnUSV92{&odZ_PrvBAFqL? zts7K|Sh62vRPatUAgBK&jgR^Z7IQ{{N-CvVF^ll*PY zI9Jn}x-E>PvvT~%jkR{{M}8#pX3hbkcis`BhXzPmKrKBLFp;}><{IL9eli@>xdofN zkJHJ#sW4rj+dI+Kc(5!Q1NBlUla>h!3lI(`ndA8Wv)HYf2xCPgf!%oyH(#!2rO%xs zi(Xf-n{FOKX&Vl>?@XY%+x_6p-a}Y8<{*ywc@9g8y^T^GgNSU9Da5r{aMcdJWb511 zFs4ERYB#Ci5=h6yZJ$W}iwr2w-9yAe=3?WTC%m7#r}A~P*U(w(gt;AauThmb!BmpI zCUqAKfq7#?o=?igNn?^p^yX7|v}+#h;K-7ipLapm!vK!l_M+zJ;!(3yfzdtT0k55t zNp43uIw;+*RFpqU3!6lE=@Vz*t{=YCWREoLQ&r$!nU#emmglMd?#jpz{C(WQsl`VxG~k857>Mq0!Io(Q zkfjMd;`AQJr_iGegze!t4!#jP&~^t8htMtrh4|~ z}@0DM%@9JpxR15l9x30i5@2F214VLH2O5$28&EJ@J}Pq3{`X5?tK9h z7POM8C)erVsIw%x=@xkxzmVZzDRymI$9ue9jPFnu4qIt1u50~Pe$}(hA}hqV=7}|>yExIAJA`>2;H#r6rNIFkJ}fnpk2>2VdA3#vS~{ZHGlKW zNbA#WR;0$7zybiAY8qG?1$dT%GM61WL}m_0tluz$j0@SJQ1C5Nt} zLqQb8Tf0M6@J1#f(GtIAeuv!>s@UK>4GR`3fF50lzO6#E{%;O3f^}MP3~ia44?~Z_h^^-dw0N-`@?`GQt4jJf(IXr8|2+#CLtjBkt{GJjWPR3JLN@I;PGL;m*{tr)PaQaJgtX^vkExG@&67AA5#Qs}iP;Ez?lFDTN$- z@dSF0-62-SqfsMp9tiiukRQK>$(xT|L{DWpG%1aRKzA13f6FA3#+cF$yM6ecHwp7r zdC)JyUyKgV-Gf*2q)__eS#Vrt0w=#);NNZaB)%dOw?top;3b}rd-5)5Zto%Ig@4jX z2fInJ{Bg!QRf4;I(M42l)1o^?AHmb#JtR~yn-turq#unIFqK|TxOLldINlQvls}0a zIk*Cw9Lwp2rBd|p$BB5bY8`1gdIncY>wrzJ8HrHYfnK8~U|=G)H@;$zPgCP}Ei1+`H$s@D7fxa1ZGAXDRTyjAmht?b{Gj2{P58bs9~^39 zX;zXKo+;f9d1o#7H~ty|>*PhPI;?5U)UWjB&H*^Q<16@`{77RBMQ~-q8dm&WIk`}@ z2`zdeiPMrY>I}vhBUop|cPqiH-0yI;F%?@K55VIlXXHK=2U&?~d>?#-3~GO8{vKNb z#rhs_w>KD!^Mv`%E6XvjNtHf`@+T4OD_&8pKDk&hmRyMmgQTdX?30yN=yphs`|Fb| z=zrb`MxvqcvPS}5DO|_o;)_ruaRTM`hvDESS$xoMNH$HqO|&1+!A(IMS+`?}_}e!f zWUQvM*TgfJFBjtAuvit%ef*g!Zkh}^E&(KdzF@z#o`=&9goAYCVf_8+9rb%6&&>5r zq8B9VX{Ymb7;er)k#Yhucz!WE^$9+h8btn4BjLhD$n#()Lca zWwI<;Y8S~y@qIzlB#HcOwS#>pC*!kq&Gg_ZY3SD<1t}IzAfix&5xo&07RR8Hsx~^_ z*a*HKw_xzFBAn2*B0qLbVwm31&=M{SdD8cBd8ZT3Tl9-$y_v$V+WMCUw8mg`yA53T zA0~ebZeYo)K9J5%#-9d4bmLw>2r3LPvKy;UbO%3^%ljWAXPpb|ve#wg6-{`feEp!W zc{{qaXMyedC9ucq2~liSAqpLnAayVwuH9Kei?+Om8;7$<*t`q)rgo5oKb;4Q9w(B) zj2-w|C<1L_&cMWAYy2XJ)BCTuw0rt<6i!PrN`1J2rZ0R*59yxismJBtw!tmxvN)W3l0>}8XaA^6;R2x`a%8xLD!6?@xvi^NX}en*Mz|D2}yLq zy8~e7SwOogM7iTHivYTnlaK?2$o(@9@xW6O{xyn?X&1(iK6O+w&;ia>Eu^i>Z_`T! z@2J9&ZLsL}C7koS!${*@8t9h=V>Dt_;0F`faYK#wnzNXQFL+Nka6Dl5uPfx<9*V|O zyzzBVH5)3}$Di5g3k^C zBpRqTg&Pv;MFXuxx!u!3;M0pRxVG*PoK$`cPF@YLQ20Bw&oh9?R$#oE|Eh214HEpPFz)bTgps-H_JViHw z-2ol2%6Nh0Qd3Z&EEdXt-Jxww9OhQDE?29!krw`ap&ce~=7-)P~3IVf6{FMA>SaQ7S8 zf0#IKAHee!*)(Jz3FFUOga7R-l*5^iH3#ZRU)?l5>o5t24)2A;8F94lJqKKNwLv@x&QnaZ}1nB0l46#l|1{Y%h${48ADGZqSV-6etj9Ykxn z2X@^agN(HriMky~EpP6C&J|afvFqQkp!3*I<60umW2wh{< zu%><`j$e8X*M5lrrJ!n*bsGX5r^jT0Q6a9;egOMB@4#-gTHGPw0?$o-&~SYakGc^o?~gZqF@Y<7SlN_D!;M>}(^f3+N_7 z>XXRx24A{w@Bl9J5(WkLB9v=TW5SKoF>bVDiGT-(xEzvTzf)?+?s%%23d-u2|w z_s4YX!A3q8UE4>64wC_Mk_f zC_HXIgGKijL)O|nIy9-7c9_M$R%Ob2e_aaHMlVBK%XEzWdXw(^ZVp#0+}N$DE9j7M z0AXJCkX;V1iBhj2EXi04e}Bni|B(}Pw9!+#%C8>7)|8N2F-~BaK1k%+qiOQ)8zk>S zI}5AalX zGTh~B^FJQ*hDE<0Qdhlvs58%k&L%yE@vf(f%5y;5r2{fnNy3R!D&$_SHW1m#@IGP= z7Ta~8)yqS4cwGW6j81~$VM)GC|4Llo{eV;$HW@8{;SbZltigIFj6V9+iw0A|skHJO z+_Gab7#4oRm|}ObrfeS(s+$0N0utasNEl_i*0W|__sJy75+=qg3`-WKqyFT*$ZRo3 zTe~5K^EMvVo=%2@{o`TFPDPB~mW#hH#iFUgdMKV{#GCn`m-hN7!>I{3Nwo1PJRMy} zDm@Ef#>pY3(D?v;S`|T7PqTncyXL`g&U(1H>^kv#5=AYZuBA(Fda)Ucdcb(kDYDtF zlRPjnqOSQ%;ez1VrPFZnMR~KJP)bwH<{X4nm^$ICPFFq%xn5 z;J|ia?gFDQ5|y(awJKC$QyCY=<@3nz))S;8%?9MhUMG9^n~}t#I8vSyrBIZ614d78PK^ega_rziAIG3 zEmkkX7e!^T(sUg>nW%xT%K$XJmHELTy);w~Ujp6A8zeMAi5>s(T8-iI);I zpnWb}ca`HMH*G=Q^wD7PEsYdg+{8J;JUV4S7k|t}C@;EqW$i1OP#YN>n|GWl?exI@&9i}Qh=oZr z6u8~-1yGWn&T80eL-cofzRDd3Je!pX2X(e0zUaXI_xTv}WC?w-FB9HfUx>?MGSE~n z7{nH7k~1dFR98a+v*ef4b>3saE@~MjXdK2X5mNkHt+v=3D24rtcH`aRWu##L4Z8N% zb*y}G0c|YQ$@abh-2Qzhy}Vcf27)u;%tfSDKk8^(S3L|x$w7_7FIpIt3zk2<`5_*m zn0x&vE1?vE&(5AiB{fZM?wd2P^O7lqm#k463@brWNDOx5*#oiwT zn`27oJ>C-1Rz1V19Vf}rT2b16z!{p4iSX~Y%E1}wY3RbTGu@NP($mNBR-gj6d{P+mMKucyPSgPZRy5=VoWhO# zM5tU_hP>|gbfVq?*2dD5kjIC}P7g6m)@h(J)mKU2w>`ZV*2A z0ft_ChPNNqlZ@#xkkk5s_PKJQrO+ESihRjcvF-3_XQUB4XUM1rCwUn@Ng$?wl-@n1 z0!fKd7)sv_%dl#CuwXyY_EPsi&7%bc=hd*w$ zP=oDEPR&!{~CxexnsskrPuCbf&qUc4T@o85)d@M#Jwi zd6*ZmhK}m%rZGb8)FO=~y1E)T#&#jDNj!qX!ZD~GxRrdneh3X$MSx9d2bG!rjd}WK zFN|Ht$MqFLcwGMr7`;8l3qJIS@Y`zfPm3>1UT1<+I>OQU>IoEI83v)63b3L$2@ZWW zpoLdsiC>cewdL%Dd&gsG4C=uCI(4Z2#qfQ)^c;29Ty;o?ob~!zFOb2!(eq+m`Usp%))ag{iO2}RQ$mnmt2{NsymdxWm zhpGlw9><1-_|IpdDC;l%#g@S1f+Ac#6bp&=qWnDhB`~czj808l3!k6%k`n3OXZo16EBMhmv2kLHhP$C^XC^+x7r`E_Mq}iZ)kMrzyOa~XgPO)&aAh>Ue!pjHtT{fW8H99N-F&Fu4R|LiiPoW$8i_v z#nTr;KiK`V@^HcBzwr2Y9f(_;U^f(o!(6E?0NQzQN9io$X-zDAdz^Xgw~NNS8v~Nm z32x~BMpmSV&K5jx>~$jvBWHy692?jkJq2ToE-_y0H+*P#7KD#v!1&WI*p~4}*%aB+ zUT#6f_^dC(Em5BmniU}oe#cK#MyJhUYW1GJ1$??5Ab2}s5LAC8e)xmcXj z!K0Gz9LR&!dQhrlflTlmtW-~gn=uJgYE1=g+mtrk{Hc_s$02_q^Yu~@ta zVNzHtW-N-u=F|=*NMwkvsjNZcXG(CnaW7%cibGnI1TK)Qr4mOzf^z+N`Ym8G-{+PI zH8C=T_J>R{FJU2`4faIsu5}pv$_8cDM}kkb1RNPUOm>Ouq3Luzv|qT5 zW~l4{hucZWR{SyYunxvSxubB*>nd4kUQN~bQM7r$4x7IE!l{@sxOMR-qf;sm=m#|; zu%7uGo<6%q&F(G1*w^}~+bjW}OwX_xS&B5WVhTwLdq%Qre$eB^$xva_T{YS#i*9TU zqqgcS**)sL;LLq9RA?TCu?E#-@9$75yG@q}FVe*edeL;}n=opzO;4vs3o^mTxq%waI)zTw26*FbBwc#M30^I| z3-YA_B=CAVq@)Szu>RUKOQjxtOj2k`;ScILaRW|ER3Iw_`<2q0h2ZyS3D00m3f|e2 zfRTRpXtA&_SdO|)Vn)@|r;nBK_GcOBvc85Z6vxAy=>_oetv^WW6jH4(-Q@RiTgW`1 z24=gWAiU`bbe&2e5tV|P@(vB;|2-(!6R5-A#3STIZzy*4DZrAjFnY{#Avt|+JJEl9 z6Q5e_VxNcllPlKU;2`x8tF_L+SF2F0`5{Mm7u&EWt_=pJ<Cy=fSBXl37`WVNp_|^Wg9z17 zX8!b6VxS#GmV2b4U*%osG`mXsM)|U1qs9OoGZ8-8Ttl(&Ix@8Xhx6=0F zd<|Jr_N|YkZ`lfRuTmg5)BydKFM{`BE6KV^U%*i>lCo=B(8lWx{8;V**nYuqR}@10Xd$~_W{ z-sH$YimL?)iP!@!u8DZf?+TVxmy+JQB4Fi^MU2#Upu@`#sHLnJdW3;GDIHj$~JW$eZ8p#ek9GHViQNh=fD!dv)K$yOe>(uQ;9rV9**C0e&eD;S%O}NA4L8{H*%HC>EEgn zp2cHbl+QXxjiy&pJRJiQY>uESQ2+%YW9)7}g2F$X>5Zm(-~pb;5&^C-%EGQ}^Z5$L9*kxkbcaK^VL4OgzqR*Z_=cE|J`xLcS^K0VCN>zB0a$Vq!tgX}x6S zzG-nEL^jY-jVk!2?kkk<9t%23ONsvc%P>QE3LU5)jUOs65ZMYzV(uG3L^u!0iupHS z?X*d7Zm9x3nOK1#(Ho3nE$f&cIie_VF1hyQ6)ka9!t(jDkajl?lJ{Pu#iwFXah?HA z*11fA^v0la_XnQ8ku3MdtmSy1M+BZnBtm{(G-FV%NdophVatCPqLfNE?fX;;N7Wt3 zqhNd3I8U(l5FMZmMc!=B5T&vUy2#PxAHbIX9s(^gaGR7EH{oC(yJU3|tG7E5(rq-isTdAV_Jr1yF%S~`!$ z7k}oEwQ9|jiJpojL+fe%%FnRub0H{{zNAmPq)^o54;?Zrf>Ec_Ne6w5d-a8I+1EJo zx2lc~TGwNc!zMEKTm?(doyN1p*=+9icC@dzglRw4(Yx$<5Rq#py9z&( z15GthrTT}uH7C+lT@Q)K@Ld=%pND}_U(x7JH^|Oy!5b!s%f4M=CP}}-)%y9=R%;b^ zjJJ%@jRhjyKkb*uuRmw->nv-yTdu)uf9QcrZpp*Bm$~qXuWV$ucMEg<_E&WFxk52f z57Vrqan@5q{(_ziqBpOK`FLS8O=-MSRU|2n3zOyGUHS}MC~=wCbq=t-x^ghX$CSR2 zQiWb#CZ4i8Npy~ck-6HkWS)#Rlo|`_E6d{rH9l$XQ_EyBZJa&%St1KBUnkM=is5LN zYELx_tuR*JfcU9bqKD&V8tY~P2M--a?m;!~4DMBWU+Nr6E!+YH^1HC4!kItJeq(>4u_CE7S6=P zb5JL54s6kSMGW45z}xe-f%O3uYOp5|pDGm)-}jZEcu=rTk}*WR;CXm9DS$}5KFWN3 zTLTYTDK*Kn1Ka0w_-y_SGJ0tkti4)6dT!Z*?EFgdwE7k`aHt@eyDrk3hdk(;ZGNQD zG!EW>jRhwro^g}BOGN(2LQI(iYh@sWE(WK-y-pMQw*@lSABQ6IHUl%&Z-LkGWVZ89 z6f?7ZCT$d)V|ZmJp>pz@>h)vBV46}Q@tCtta1I;EbZ3ci*TrpxRP8u)7Y~E&^9@06 z>^}Su*htT82=I#5AXw!ivWbcjy5dhk^7_s}MZaeMedrd?uDsY2adR0h-f2=+Vv!;MMZ* zgMt#9)R#>!^vgofEglKV@#KxyI>#=0F$EG6Wx47djhOpt4vzgG!psMz=Yhk#@cHLqXRSBU7=N4)?}hmA{tzlZ4Tm$0f^!Y}h$w~=nDA~k?7Vh~p3xNa z%9Jd|l}#>mo^h<;?6{oSW)T50%ky!sstfyG5Q~vz4Aj;XL+|yw|% zzORmet@}I{=Z#+tolE(&a>_~=$o_%vJu7Lk*bU4K$Yy%MlDsPKfw!O3`4`*6@ad>( zXuWZTOqmplTNdb}Xt@fAx&CIKB{);#DX(D86b>589R{aE#WY{U8Iyc%Fea*;sMl10 zPeV0Pzy6mPs8`Yz>2h?Pwkk}P_(*e}Mqt|S2RL`7F9a^01Fduhsw|eI(;sK!>e;6l zPtCd9lcuxjVii@~ovi`NrXGa+xyK4URK}!(&&h}7*|;RV2*m!1!{Dz+^hWCoa@+Pk zx$lsR*=9~e@6K!DcH0rBEbhP;=gOFs!lM%qk-_-ezuu=OsG^Okaap|e-< z+u|sY8uCZQSw6Tuava1RzezNF9uwAASi|ZiG~mokEPL~t=)YCR z2|avl!gTUrYdGmmEAtLGM0Ny-T=AXOtk<5ilzq;FTIw(M9({@F{1Q+&}M&-);s&sk#!hP3$7#d!J+U&KhV~ z8VaoeGq~-#3m`@AEPKB)1>_Cy6K_s8o9!;xSCHSdx=)llBrzGgt!9F^0|P!?9lQaIlJ|Rxcri26-5z(LrXEIRafE zM0A}}@XeB7^2g>q@jY@M1N1i%jR$&kYUz6rX)(g!BYyOKc|RJPWW<`Rxkbzqb)^obA}mi(X#_ym^?n&e+xawePMl@mEfc90@RWE25M_w zlCsAa;Kr|q(7`+)hhY_H7f-_LxAGwGh&C=N?nXuHFC>(s%3bmFJAG!V!OOKD0Evlv zNEWXPYGhK0;-q`X2;xrk=rRZk$%ILt^`P0mjnz^brVn3sf!)}A6wfV&#WEM@_|a|A-$y+} z!q6ykJe7)9;`*AdfLrDScplckIo(Y1t5o1~MhfvvHNp76SkUrlMXl=REI+FbXJ0-F zyGOT^TD#HM*;-0e6O%ZmU3C4jFpXsfV6e`!JkBi;Tpqi~6`m{yEghUaVt0+xU91Af#-ht@d62;z8 zZ74aF16lb+pyY5HPARIukLhXTL4FjbwH>DSpI^s$HHWbOw-B+OVT{YUUvMD5jJj*A zCAUNc`$fT7+h9c%?XK%2amQoH*E_p#{$>l@Vi1chVF|F@KN}VbdbW(7`lGSR1U$5A zC3wCA7ejN4 zl5m{;8+ws{3Qa!RZU#Mx>H=;!-sma#;vg_zaL$ zCdWZ!@GIF=DgmZ-M{(7oFr!q7C(Ka&ETY&YBiR4lrHhQBaH4E2M7hl%%H9`IJ5Zb} zt#+Q6FA#ytU4MvrT`0NK7KpcL3z_v<1`m(xW{ab5QmJfx#`NqY;{0|D-67OU4_HmX zQ;J;Hdz%_=RnCN#cWQz@$pRwtQVo98h;UaVOvf?TGN_Tm6ZE;<3fzJOoWY%pKch6@ zd3YGUPUX?Ht2E(>ksVkv!rWk=Ho7(ZFLKmp0!P&c8X{-Wj5tqrZJs!dzkC9}X>egU zEwSptLM;-XwvcPSV>*=loQVC-0w207L)q%@>=U6JqUv`HTrW-qorB`Yt!x8JX%RHt zIU8EkqhOvYkLDZ=r(YM{7o3Z&0RN-bz^hAu(`NNJ&+Hl-Hpdv8eZpwx>_fO~@Cns) zuK=ge>rCH35Gr%Vf{KnPl~C@a4@$}*af}^4{9?o`9d(4N>&8%zYY*wJQ`6xO>j=ry zSWw=QK|=e5X@;;eYmlIiyVhpX8SN&Jvr8UqO;(XNkK2i@+&!AGH5JdOUV*-4w`fXi zGn-?zkujQa9_Lo6aCi2u!3Gx*ewXc7I(f`dY+cbs5`C-TLF`ldZQC#mxF-VJ9?vGN zT)^F*ECbg!JfZdUK4e|b1^pH`G9=!UxxTk7Ib)W3K`g$3y+tJ-0zg`FHX1dU@ z#aqcvlR{YUCQL?sh=8{ECiq;tm>$U3hW5GkARg6$-vo~h#hnINGHx4tqc0zK99u|s zKm0>Xm}Lkf>9+0F!EJCEJ{BFnlB|tS!q5R+b2R@>~T!q za}RC3GPkr&2;COjB_|Uj)=W z6GA&R1y~>4NcJDnW>sLrwZekx^;(pr#Ux zwxZH>Znz)JGf{z-31LJlqYA~_JyG3m1=Jm%fJN^l!RF?6l$;Vqh1Z5)x&52!&3q}! zmHz?j5^s>!xD;TPDGU1gZh>s&9VU+(1u_mFVXuY?tL$=u?6!2E3q99RSvm#xZT(L2 z4o%_L_CKHpOD+K^G=k}O`DBC0Z5%Vk1lDYx1|ENJv%!**r03)@w0hCc#)*8UfA8F4 zoMUCd!}<%>DIdnHlG~toxq}fuZjatl65OcdG7@!QB}T7ai}EEyOoU+>WfamNZ+tu1 z-KBsJyC#G8TY-avS3n2~lpJw%>nMjM7!LQ(<$&kx z*?4nRJl>I2g`E@gQPQpix_Ix2>FHV`q9O~IYcw!?#Usj?EP|w$f2sM}n{fSp9rQG6 zehpZ2sc;Rf0k-`NN8hl_W-SvGu zrE6k|f7WV}>|zOFKQGaJR1@zLJ=p54!;jmVjNk68K`%jVW6gwm8vbZ1PU~2OGdPvh z=VTHw{N95bqC4QEdMORBddlVvY{l6L<3JLw;r(UKxS%l*hUW$0@Chvxe(VoHoq{#8 z^D8L4`IE?oou?vpnlSKt72Ekmk#9W|4N)glz~@R5q+}#w*Tujp{nPR!VVoy4Oo#`| zgVJQuxBcuzWeME4Kn`Swz_6fnw zTgBj8^%pd;YlFjFAzW3n9yo7GC#A;N)|Wrf&q z=L}9dwioSB43et1TY$`XPihz5AhP@w=-yyM?;M^9J-Q7*av3QOY~*s(vAY0c(Tm^5D<%YGQ6QdAmz7~TZ0v#W8#z)E_j@eMB6 zREU4Cxbws=wSd;1FuJHM1D)=E!c4~syuE251Q;j40WDJ~evwG~p9$)$ck>`HF$~{y zXko`dEbwPNft%85_(>*|()uh&I^zmgE5?D@$(f)tMUf`e#}NLxdZXylZg&4>L$I(& zV|F*}Me1omG_MPjpNk?P*JKXZdw->!=I`m<%6VM5!b*6Abwu_w113+k@z$$6@>ljN z)ervy>vg8k%oB%jOrbU`9NPy5K@Vt5Kq9KNTqiwbfM|Y{go2~@nLWYZaA)@r4om)l zp`$xMXcq(PZ)=eXja<6cBm~BCbMbOZIqbOO25T!Lu}*p|#PjWV$r1I;m2L9G{#q6c zdiT*oqD|zTy$K#;d&%}W%JA#O6--oTUio9zLckr+C59;koGfYYDdKZ-?^89@czH2DF!6!S*q| zSa-34@#Pf2W0%wPi%%%no|=W#wRWuT$`mm0j%M_}1;NCFr;w?PBjMvLaL}O&>;;3jC~k_?gJ%s@LecfcWE{?-J68x z#$Q2iubN0;}d-$-VW~OUF5+vPrS{qK(C6IjIZP?>ZVysKFe&xBZA*|nQuxc{F~T3tr}ko6lq2E_F;a)|(_KTn%ud46 zjsK6K^9;o5i{rSNQ8o=qWHhZJ>z>~eQIRMOEzwp=LrTNmJ7gz&r;%-*bNi zy6A|bREh_vPOc|R>O(T?lqWOr@E%b*AWNU^*o-?Sw&3?%J(8S~hi|`V^P`>@&uJ6n>nANktMv(v^?J&O3wF?w9*84UK>QeiGL&~#T3cf%H^y~ z>J?bE?IpzJouFuPk_LK;kmNa$uI(!_xYWr}};oU@3at*lG_LAx_9XvCZ2fzC=F#DP>gvykg=PgmZBEg*lxM!3^5 z4b=R*agNqYqS&_)^q+Bk!sQ#tVOvK`9f+kR7Ws6u)ILlTF2Y&O<7i-^GPVmRv4hEn z*?;3a$hbRI%%y9~$SGSx_%Qt)zd6nfo*2HTf~kJ=PkJ(!)%#06ynjGuz1s{obz> zW{h?fV!X|Q3~00A&flRyBrAr?6Mu?^yH{PIzbXW)Jm3YT8%E_hBljh;<@u=@qc1|0mkmZbGBamr=W` zkhwjniz-xcT7F+Pl_6p5zUW4N#Vr$J`YnSVbKDL-UlbWn9Tz&<5Q-t2l!)uMjnKDC z3UuawqoUREF#U!%D||zn^r9!JayUl>x%0`;IvcVv$BL}_v>3%!3S;QRa{AQkGF45z z&pMLxpt&U%y1yNy4V;#2s5gi@8N1;3gea=GgUbLN@!(tEDmlLkn~gKoBZyfB$qUx$V-cY`Tc zPL&r*BctZC;hRAQbcW@SfX2sof97m{(q&uxC1pYnOelablLD-BDgZQ@`H*V z(EsclVZy^3?3joVo44#9HH!HUA7w?r^a5Gl^zGB=EA?npXh}g+@ixY0U<34aN}=b= zS=bU7gz5DamUvoHuubzIyikmW7ttDEvuGRh%;h2K2miqfjcOnW+=|!vPUM|SFnm1z z(DG7E9FY*81t)rf;Y|NIxG8WzO^Hja%VAMGB<%+e)_6hZ>G5!QWet6xG94TfvY~MG z5T2N2hau{=7)88NA_zDVN zVvmkxvB;k6VM}d3k;J(g?6uhGFgDc+GY9oCQ7jK8W(lFQy)?#&7P8G|3D96Q%um^; z59Z#UAiGsVP=72HdXDwbI>|_yCS67k-=J`bUySd~Wa)?T?l7Y2g;VPFv0r06Jc_tM ztq4eya=IcxnP4v=$4p&r9Gz>u-FSIyJ#ukxY26U5H1!w~|+9FQZBOGd79o z!ApbnVCd0Kj|y>lyo@}YHHTnbm#|>!4+;ElNi7%`&&Sbmig-eFE_M&@V%#sqlBo5+ z;IjEzeA|}C?q#xJ^$jV(^QZuL>V2Kizdj7F@F1Ds@`?S=pn`@h)Q0YbQ^=0VDzxB- z9x1UuhO4%;pj`SEI4>`RZFWaF&)WtDo!n95U^U#kdjc2Ueue85ny8qus9>8x3F|ek z0cNP{!rNFwc;fnw_H53;+T96ws;h|H*&;#t1D3QrH=72QY{ZqJd+1%?J($RK&aaP( z2o~7BVCAh@WW^ilOHGzc%6`enJc@zz0U5Ly?O`5GkrK2_(W9zA0@;g8Q}O<~S{xgn zL!Za^i0Tp&2m0jfKgMnGm9&2w~2aui1tTm=%ly&E|T!(y|6}f5zNns#lJ3*s3#$TR_5#ITNO2s3En^g zwL?%4!Q~|0$bzDD54enp^R5eq$xeYfy64ORvA-O)up#0R7Xjdr0g*OQ}Xs*UP zcD;i3xpt$+oLDkv2jQvw_<)JCROph~iLiLQILFJlLBi(U1MA0r^b)N_y)JdGgEIy` zkAzTE`=#YzU?(ZQwjQ=5O~rmF8)M(sEm2p4`bWXyN+3CM`A~ zcIONG^-3;HYgZ#?O)pqm$0c0GTc0x9T~KLHJ3%RizAmgZPd5?g>x zK|8^I+z4?!-$?EE?Spj#iMYln7i_(@)1e)kXk8Sg`J41$vCS>ADY^-tC0TLLXaWjv zdvA7nr8eArc^|q%vq{E|F7+w5wXEUTyne;%#;?<1clWvH=pFxoVAZOUz+%Eza;F} zH%jBAa7P3WjV1(=`iFKX{yLx8xXB;XYwtqoN)Z9`$pKBA`iYx*DD{3m5x#X*TgquI zgQ)fIK-D*i-deB#`L7abymbI-aYl`B%y)KU@M-qsc1gHnH4(RTZGy~o^Qg>`9N4aDK+9XXWmtilB+i`lmi`Dns@7n^Q% zF()@juv>G3N%P2cG&@{CTo*K=ZR26+cDcyu2bW23)HxFLq#P_V6$FAMQ*eLVKMRK| ztFY(G4xIWeADU5*Ke+6cjmrf`E_O;-W{h#OXtRdg5-5JZMFZwiKj`*B=9{0)=6;}cpM z=i>3e9~Ku>W0)PbMQohUed?RL8x=?=F)U>9JC_f+UJ;CK28J{v`X;J&eu8xW8g?>Y z16J*CAdRdJ7~hq_;s=N6y6wr{DAYqB{2TwT0E1i4pHG!xV&!) zM(fUjzED3*F5d`EDWbU7b%YkaSi=}yi{sBrFov|PEQDSe!t}d~nU1QbWOnaK@K4@b1(DtUH@c$_(C<6SddC zHsLq3vN0c|d=d#UH^=&ayNGIA0>kfW7^g~G;N9nefoSG3~vVYu$A#Se-2x-4$@*?FtweO1v@?k!|r=UMCIxSHuPOF9C(}${d2G2$n6TK z_2e|J{D(9lkUNWB8Kcu&{xSwmPq0Hm6gx8Xc~`Hl!G&5!P<-oRJZO3uuC4dMuW_9) z)9?kq*w_@-uh|Q=%^BoT-Y<#^4PkBl5u9!JgD&m$r04reaNWL%_|7983sskN=z1&gylfZNr6W1=XzN61*mmFo*HC#QK#E^Xqi<&bCM0w-$$NKuxKX% zFV;hK$PXeLzmVu8?}mX139#>xJ6e9!uvoYMKdKcl3l$^!m_z<|S>q!SaJ`Y>hLgAG z#0!$3_j4|+mI?$F#eYo9*(bzxynxFqZDeO7-7l^|FU#+0}pO zJ#5GI%^uRrO^5Jz_eV>=W0hq3!$V*aEP>5-(P(up4jaSrA+I$NSALAZeQwDx-B^wH zz&8)AGc>9A#dPw#Iuw^#CgMZsH0)~r0w<(2iQ~vxt~=TWb8iD+Q37n#>3|jG3(;@! zDIB@H3+~k)hweHjOi)?Eh?@1&xTG$kQ#1IBS+Q{WrzU)ww2b_?X$b+A7r?!;3DDRBk{6iL z*9{5qbg47G=Xzs$mO)fxu7wtoS17lz~;6;q^tHI?PGRBxl2X7?a!h-9abTmMiH>1QC+t3VdXy ztIctme;kxPs)f0giR3Zo50t6~vLWv!Fr$lwnUGM(`L=A*{^ zSEy6=l-5t52zh2*#Po7C9_wtR`g#*dW$rS{URVbe+oyvj%45jVD7q=P78ahnPZFoz zqGI-q^i@qVPLJ9+(^RvQR6bG&%{(OW~y9{j9}~d3gK1Jo@HFL2kha zsgnP|`aO9L-%r+&@cm~n`d}K6ZyM0~X3S!)zB+miO@`LkBapd#28^3{hV0s*N0x^u zKu(cU35Qa#KU4^^_Y}gdo#mEQau3+B5B1phwjG?>QZUuGncN=G1dseo z5`L@@yRJtP&p#W<;^!A=%CtVZahWRoSEEC$BW1zRZ5r}}x$bBZNW*>xuqR$j4gpDF3Grm z>@2nhcoGw%MG(nm!h|71d{f+Q`B6TT86G`Kf3DNR2r(^=DWZs19?iwKTps)Bl~i2a zBuPfCWO-HBykXsAOCo>O2=AZ_mn*1cOjoU?-uy=5RicS9Kc?__WEEDdURn1dL6RLY zV{xsIJUVQih@LJ-utM|` z?u_5?m+DOrg6bQ}Fg0=xlufb!nTslDoZ%=Td#j6LX|IJ@i(nt^9cHSep-QJdhTLhF&ouSFXIk;-EfOxFB zPUiGZ#tAPLK|gJ#Go)C$Gcgg|>b0qh#9lh$5lP?LTw^jt4dB_kB8=M*MkNB*)97_2 zWQ|iX9Xndh28uU;@B$C&nb^+USvZ$Dyn8ph_SqR2Kca)1v&ZOwnuuV)_X#{5od*Ag z-66&Q0CsH^LvNEqG(Bb^tXl4hPTf6N@zD(HJ=fwTz6w3KXfpk;x0sr`N>lTwL~`Zy z1;$yGZ)taaGrZGLflNb3)VutJ-4~k!PdE2)+=(S9FUXN^XOgdTm=mE8T)W{q+R_9;PN~U{vCP8rOV$!p73_ZTg06*7zkQSN2$P!yd zG)x`FeCMH$S2Q$P8-npCCpaC$@Xon}z@2f^Ee(8J!S3ZHVr{ttW6c~OT=ooU_!19d z0g5Q*@`i=!@l+@99H);KkyTsMV1Kp^stAN2G1%1d)wNqB%5fa{S;z@w`o_cD5M{u* zEAi#8bUb=I9+oyef||b~Si2{NCgUxBrjI?5i8NwGo%zIlwl8+vlc6WJc|!CB5%g>y zAnU(;go|g+g5$!oaQEar;(Uyclh+7gAYNs^EDXbWZ(hObd&y)c_>)abm!pzn41$c`&;7r_;;k&`8Es`JtuoIebA&t6<-E*5Zl>i_*%A!9%;XT z?BRJ_20sn8MP6dpj`hU(>MSvs`LOe^%q|$XD zoCWU>=g=XqAmZQC4GaA|At?PDm!V3A*!4;zRW^*{v-}_%|2u;>J0qBnRriSL)7AV{ ze*vB%okE@Hx+kwU18UMfQ{gXf<`G4|Ly*7kTFzOLR%-nE25q~tKs6Efv`>y31+T@Af>PZN@CH(;rL zA&%avz&MPBL#IQ~Z^hTTm96^Fn9kDaWkcl3xk9utRDh%-2j~GYed2uJKT@&h3MO&i zS&2`R@aqzJEcX>+6zxR?rq|}7MfPnHGEo~EH($r+m+N7~eLMMb^bmF!pQZK=YNTwI zHW)7t1e1nOjGSROTq;ZAvdin>SlJvlNJ)ryFItjeGVG{o&@QSXS%snD-)MNsB${*F zk8{oqVOgpd*>~(3{keQ8Xovh`r@cMR?AhJH1TFrAJ&_5l)W4;?)`eYcY3)U9P}IRY z3dOLb(H&?b_w6iAs(W0!4f~G>NchJc8T7d-x+1rcdG-s~VB0Ein+4*-IllM_~TIy!oUmNQ?=DS1g}4QnQv8pWpNd5P13Fk3Tm%}~`T~YuLgKBjf$$A)JYP8loh%O1 zn&&3yuyqP?%bEmzx3@!)h8%r*YCi-dWYFp?6?8tvCAEHS!ZShZK@c86rpabe?~upz z@(MRxJ@N?5Qxq}0Y&zZ`fA}?C7hudp5wc>lnN!(iMD6{2oTXbsubh1eW-m2qo4z>K zh%JYa$1~wjNL>es|08+L64aOLJ*IFvhoOHGvtjXZ#&erFC|w#Sn2`RIc)v@4 zxRH6}_J4_>uVV_Y202cQr39rnW2yLEA&^^fh75nQ5L_EyP8$vn(3r2&1hE|3EaQkN zK7>V7e3>W;<%z?Blbv``I~L^2cH)t>OW5~X=g{2C9%uYLi!a4@q3rx1P^{d5Ba0r= zlt(%kLqlOz(R{e^;~2D>f1>eN0oUDMK+3yf{620bF;CAUY}IEHdF2QFy~YY+?-IPT z@ik<%jDpszNR;kZ!HzATX@G(!&h|gY*h#=@Iy-Cx!VRe-K(N0E*0eJ(mKBA zZDW!ly9n0EUL*7GjNsC_c|NmDTL<|0jmBQZJ{p8lvXgL0_3lI2Y;?xD}Xw~N^>@)sA7JJ--DVLo=@}w5l8m5qu z;c4LAc^OY=O$M=^e0;5Q2p&wCfI0e;firzUOllgMUe?6vet*d9sReLkS1x3ECjvi& zTf_Whnzt_&UafscS8R)6{+y1bK}!mtcSSf2`5l9)o3DVJ+&GZEkO~I_wNPNMK#v{% zi+Z|G@jjQOz4gM1sg|opW#88%W?e9Zos>Xt&S!9RE}`D3x@eXl4lH8>=QI7W=e8+H zQ|`pG;lA{jbTpNISO(_XUO;GlDS~x7uC$B6!L%kYe#dFUkq37irU#-*iDz{4>$`ZAyAdQWytI8C2Wn z2F_WNnHzJwQFr}6B3c%XL%SxUNog`Xp03A#dUGrN=2r;D_eZJfcyFS~Wh_HRI4!C> z1N_b?<9NsC;5&62Zt;u7Hrop{xZ)2~2UihFa1N0JZ*%fiyJr&>|Su&_8;q>IYxGz>)|}}ZG4F{c4XEi z5fs~~g)Ki9Gdmrk$=WD2%CBOB!z5_< zA_1*33UEpB8U0T}3D=5V!FbCoFx(bN)<<|jrQTues@?>bt_0HL9z)chmkMj6{t@?% zJn~HGI2L*2p+)c!I1vr7+-6m3yGxmM`ngZFPnA5 z5tZ+*fU*oTc7|yHcs-s3uVT|-uG1Vg{ZJiUFfbjBRP>-gi~%|SEIM^4pAlVh7H@CZ z$s7$T!reX%7=~w2+Ws!1cwiPBYOAJe*Z)OP4LSUFa6cP!QUlyWm!jc{Oq5e~!^?Ft zQ0j1-CZE_&&2>iTo}DUqQ@59Ss47pGncG4A_aH4k@{eNx{UxgB%NXxej%8loPRHS0 z@LJyuZ^pPhe3v2oJK}_9j3zmme*kSux1&+pI{GMXIc|wPiB_XCa888@Q=lye*G^_& z!LQ~zLB#}En=qt=JdXEa?S3dp7QsFD9B}g6DCSlBHEO}Jy1bm9(#CJ4 zsB*vn0}^}rW7*aiNd6Kdu2Xk^*D;#bEJv0utz#3F<4LaWPda6`1@q+AMe^!+6os4J zaOlT&Oc9r)uK9DQVc7vz8x$><@d`<1!xnlgpbcd5hG|Eih(OV&7kVGOq|vXplff61 z?9gr|v&&|p^Odh;>-R+L7oNkOzT`$^cVt?Y%=3pO0cN;Ji=}r}?$Fb9aYQ2~7lkDJ z;fUc?*zyrT*7g)^diDhVEy+u*lOJ5T3)P&3A%hJC zy{;~f@skcsVyB4as|-9cX#*_1ei2Hu+sJnHNhD3uhrWAQjb8p2!M|o1YIqcoJ9H-R z!-zZ7`6uuzl)52qVm^s$?qsLy_mkhXli4`M1 zF4>FfqPDc_;Su6F{t$?o>p*jVI4aI?;cMxxV~m`xP=|+S@yM0!ocH;jl+2JI5ie`l zzH$kOP)?%Bxw#~;t_vP$Zp5lo0aFt+lkVKs(w6#u{!yg3EdOxU`4=etRS? z4!%d)f7~E@@>b(1a+$>cC;;tFGdeu;5Rv-497>pJ*pr+@;ty@2l@1>8By@h=<*~nb zVXGmJwN1jTQAD4Q+BoNG2u{uYL0uLMK(~$^mF)2X^~f-|_N z2jq#Q8pmegfm^|Byop}e=%&Yn^5U5v4j<_(L$2G<)JJ=b-EqxuHS0E)B~EX-F4glb z_<5%yJ9K6(c{D?pkcK=ker4G1E>;~2qpu8#A}fn z`4~5cI-PpKoGeepUvfIkv{paD7k@=wZsU4_M_XV^h@rqP>I+%EgzMooE6@+y@<^&` z4kIGO!;`CpvGemha4>JC{82kPTX_q-2+E=UtDn=3O-J$4_dPN4B5~KN+o1o?D zT4wI?Mby^h5b^%rO=z_nKED2hwi&f?Y&S=?;@)c}ddDF7zHpGF7;0lB*K59b)|3=A z&jni-eY$8vK28a}f@dZ?Ah&;xQHc;CEakL?Z^9$Yr_nZ?XZNap)#U=RyXN4` z%xgL3UJI1P?tuKANyN}L4`Q!`fLB&9y`R6C`EczHv)F5d&ha{e|B`os?rD18-_ra( zF+Z4<+e(a*Zb3`<8?vNoDm*@VjMPnkU z!<@_(XrfJSIm|VSqiXF3QQ}i2lgh1+=I8Oae8F+}X)!|6BPK%4s#^G+ql+dE3XrgC zElgkjo`iLDz~6y!m_K$A$hP^oukQxq(7f0`;Ev3`}O7 zBtH*qr3D;LNV7DL`Z!fl*}dJgpko1Zahe!BPppBc%E#EU_Yv8=xCVOdZjvRpm!R$D zWf<5W2ZO_MIPiBOJZxS~T?Ss`gLoZe&&u%LQ#qV+A|0F_T4VHqY22KFPbBT?cG5a> z2&R@T$B#no%$98rEc9RVu+mBsCxv8^&(*I{>X$4S>`c4hjuVAd_#y+`oXk^ z-Gmow^61q69Lf5{-hq~=gPBfRDu49$GSWm{Eo^Im&g z8=pxBhi!4(!B%owyPDZ*q(i>F5hZn6O^E-hnN@bhD6?oimKqr1fV_ zBv~BmHe5r^p%x72cZWybn_zL?9GJf*2?rA{TC9F6Pe$*B;qGu(Y)xHBipC)%4%o26 z%44kRuk}zk^)1l`3Gy+f3YOa5W{vdqam$K#Y`=~JaewCm?_b_!`8hA?XRbGYVl*C5 zZ$2n1w(`TCh|mY=Su{H?lZyD-h(fypTAXfh-oSo9Z`df2%pOdPf zvFbH>w?mbuU&}|UHOdh8o`paaO;l??N}Of(QN@CEQg+A&dBXni{_!^2S$zg}{mDYr z?FCe>&XbvPX+0`(deNarQN$!j7`{s_#g$Q!XyF?UrAZE0AJb+UC8Bwt zuuhqz7B}Pgwmm5AolG?S?cnFu`^Zh)!xf#nDDyUhq?u2|w1!aPt-1{*^iH$kfw@@b zdYgWIe3Z#{KqtXF5!nig@imbguSZa$5N zj+Jun#tV#C-~dg=Z>T}vMKYynBQqM+3TEX>Fcct()7J=C^>1B7H-0}zwHRaL*g0b0 zBS7Z~%c;6}CMnx2M$F$jLd%E`J)I$r`vsB^RB{f_d48nPJS(hGZf3o>9O|&xKf2=P z9*h~uq<)iCu+BXM-rT#&wCCo+2bH}r%Uhp4=^H^j!%YO=!aK3DrGnl_tftj#3h>O` zF37oQNR^$1puGI`I&j-1wXccw2$1j9p=z%Q$S`Iu+M%j%Fe` z+unR^DgDr($ouTw0yoEbl8%PE@Tzku6e(;Xd7b}(s#_&lay1A9!8gD(K$Be)|ALl& z5uveHY~U-Cg@?pyaP;6Nj^P=G2Kh4}C%Fq|XI&+Ci97c!V=*Xw5i|-WlL`4_boy91 zUcFgKzZ!3XeIqw;uD&VtbEpJc>r(t>FN^D1KM~;#^T9G%9xQKcAWpB9&^pu-3a&@6 z=6RRdtwp8~Ad(L+h0-wJu!qsjcuSh{;?Y&3jM%$;0iOevShaftNS8GcZ4WI}TKJh= z<`hAmT_~pa5AkV5WiYYw_9O467t{XkUM4#F1coIz!>l?{xa8dpR&g9_(d07zTzsFK zceVsxgdc!2Pv&Dh$9eSTi$RitC@f%P05w`^A3-Qz8)v2T)Y7H zT3E5$>YVWGur@wn(@7(@FH7mIgWUs<>8i3y@-U~1I>ajo#)~^*tg;Qcu%M0e?Y-a( z{~WmQX<_xB3gL4#6Jmb64QgXb;f&uL{K-$py}t;rIklK$c8o_Gt|MP^@{GBPl^WUX z-T?+Z24v>MbQliZ1+^MxDBLpwW?GkVO{*r1rS^h-&pxq-=o|K(ntqMNd4_7_W*e7jeo_a1j5r*ABZS<5EqK~-n#=b(;<~Zz&{w3$ zvD)R~Uy2AVcaDVd@3ZNVbF=CDb5BpX<5IA0NgX*oX9Z6>PMN5u9KwVJl95+G5*EFRBOSeNSiUNej9Yk`tx(9t z+o?S84+*80k|LO!|4Om!{$5zpsf}3{>qy}ib+YF^QsY=%5DwE3)XC|9^CMZ>QRYrE z+!nG5zk9G!eI}-Jce+c}FJvC_TdM*mFQ zvGfE^{uoZXVt3&&%?c=(;m=gsYXOsE3#~^ElWI{vjP!HHCG#KR{k0xg!xloE`!(#l z<;UIsn@C~)4ftBsLapt4;ca3#Q5PG)t|Tv*ZybcXl3hvF%-vuhU2r)rpL8s(p!0t1!x1r-KKSNGx;VXe)sIuKV(AScSc$koXCuz^ISt1v z;$TveFO?p@fM`BBNT;gX!Puu*lsMW?>ZfhQMyUm0`>hSL4MWg-XCcwKD zNVf=C(z|0Ouylrp6t{8ul9M3AJrtdUZCzTK! z_RhoxA!&SSUy6!T(#ULIj%79IVVSnGk&M2VM5{D=y1P7{nn`D4f?27h;chXouN}i_ zvVrg>G9Qfk&f_n|3BhXg>!gD5{|R$(MYi8MeR= zy+|*WF2q6g3B29gguwD>5ZwD0OfIecLzfO5L1QKfqck-@n9~8B?%ZYtOI&fXSsVP_ zZAFAVL+JU;<4m686s)P1g(>4x5sd&=KR!s*A|5f;LW^KYl`Ish&Z29(me6fdTQT{& zKJQh#Dwqp&@EXUPY1wTEw?8)FbMZ2g`b87X+M9^Y(;X=3poAC2)>18t-IU+)0j7u# zvLD}zK%wR&LAYBPO*nB9zH9VT$4#>_sPwPJFOmo^WXj=a(ory;D}-NfucYhD#Yo-x z*UaCg0<_r8kUgAEJt17p;_Blr;{K=~h%pNZ_Z4y2HUK|(J;n)_8!caZEC*ALy|;qX zYafQ1LI0m*Xh^$;7YA!_D|Uj?;6V`Va>dKPCli~|aF{?|6V;a$^!4UOa>Q;i(LGa5 zbU1d!zN-o-lb;HGn{$Yj&|JK@upCYs%m@2Gf*P3vM8mb6h*>Pcr}e2MF76KH9TsP{ z)ZB&eht6Qr`GH+?QGr(OaDYsyeMER|Ci-662g*)D zOQvsGDVW#2oW%XegBd5e&Y-y-DJqYLMkzz!KJJW${0hEeayfDSahLMF%xUXEZqC=g zByz$03dcqWgWvmQxwFQ6d~4}WZ|`tJ?=>H(rtm&`+pibCrbYw{zhg$O9u|1Mg81Dh+jf13ZEE5|Cx^99Hk7o;rEHEvnf^aEJ8BnHO^4n z$T6eN@VTk%n7&w4(4WRfGkaN5ckmWH5XLf#BsK63r+Zv7n~Jaglg9`JY2MZIr)gon z10|DfaJ1|n23?4N!iEa`x-A5b#Y_@p$P9vIY%BTRxfMj4A5e>pI(X6VG!*_gM=}PM zQ0un0C}?)Yg-7Oq^XLZ9j#vm6zCGg4>Ny8<79*YNIfqz0TZr;f_lbPkLvm}b6t1`u z$ClAD_@b~J9a9&eZhZ(D{V)NX>y_c~PCm_=w;9i`{{(X9W0>(v)A6r+3uJ4SlJ`;F zaL4i{cb}P!#^bASpe`J`j3ZEetO!N~E69iVt)Mw#09mRsm>FkGFDvb$L)+AGO|J>6 zj9bd}*H^&t;q&y@q*Jh=`3?wt+PVK+!beZVvCHiV+`M)NHw-;tjlOU`WKs^Jq?U=Z zj_!xo{e{%9y$?Qf`l$Y@2J}|Vgt^Zj0{87^&v5#s>5hKlqtT2mFBq8ZvJ9@s?144= z=kgrW(t$C5LES{w!ltFRMBmp7$5$G_>8fy8`P2YDCpA&wE!RnFZaV$eTtf3>EMVsvGq72ZflYn7oJP%ng-;y~-DiqtCN71o4_`t5tZq`gb3e4t z4aQwo?@6}71Lkd(G`-Tf4JzvMKydXE?X~SBYo|EFpJSo8A+w06`Wo z^w-TWqL||h;`%Y5q#*)fH-{{T#&*C>qXalB$iTay1SErMVZc0rk@HQa(!t|+t-fL~ zdo%*VIR0a-qz5(s++wNzcRFb0y~abo{-Btt1a^K}2JZ({u{Guti2U@2e>Y2T-Wh3L zEl(Y%?5tzd-!;MCG6CvemBg9O@R2FY1LNm!Z0UG?eLmWwR7~ z>47;%VbO-wu=U>zlpZo8)*m^~$m$U($@75i35RHq&UJ22S&IKwEP?}{IcBEYN+^t5 zR+l&HCG%1{n>O_Rq*eVNm~X3&(fFgu>>^hiocrfIZZuy8o(=h&hxnJaYzu?r1ZB?0 zEm~YMdid8SYkkFu3Zi(|7VUnE}vjfOc7o@mFGzh$ir7Fc^H9AdLZ#I z`7dfE`MX<<=bbYJYkzs8rrZKJd$kDLEvLgiYZ?6AG#*-moWQ;!+Wc1951bpN57iX1MT)HJVSM!*`X4ukarR(4 z)Rzt7yWnt8TgivAipOO899i73>>eIG--6fj+i3ptP^|HPNwdO&;M%JI*7waG$U3~7uTJ1cO-iBZyDum_f?tRo>c^YO^m-qR>6@k90y(S4=&Q+a++gq=hRj8h(y8pE5}(caQZ*{u^_Ub6a}p z3gC}b2Y%S5#f0}>foGRf`8z&kRF_1Mrt^rB`ml1I#VGP{?f&VUU* z*?$u(KQE!Z9X%-NSW2yAqA;>_D}In|BPI&%p!BPoH0nj+agEv7Z<>M9!G4f`xB!~{ zHQ@5^n>aH(4R_^^kurN({J3o^wNw<4mRl)gl5HpsM~2X&Ea!AwkqkW^DfF;KkFU~=K|Jl%U@Bof`JcC|mGb|fz!x;i;dS^tA_qe!>THA!d1wk4(rX`bW z^Y!R`omTS2s0`n#N?^jy3V1CUL3WvOuK1Jnq{yX$CK>DUWXk$c!?h21pQJ5*H%%t( zgniHDRJ>*a#SC@oHlLC_@jK-0>umbahl4p=#9%}V=Me6EhHF+SAvy%{g)gNty*GL%kR63E8+X0n}8GjK<>b)tBwZ0P-hmKhU5``mbXCT=!0pOHc*4Sb-ZMN6@XAC6C2X$~-wbpal?IH@L)kG&EmPl`XNHJkPs%c$=`|U^2gf|~$ zJi0+zGX>Hs^5Iv&B6wWgMWkhuQT&Yt4cen7n7np~mb6ZQIbJhqc7zI!%`c`MlY$|9 ztvu9~W#K&c8;lSe2|Q_YqR%lFHVfS+JFnfst9P|wxN;M8O}fIal9MBCV1z9_GqLYf z9GtpQ&YfRXaNhK6{Ac@q>^e0M422s}QtdPvc0Z!u6F*Y_ktUin=Pk|o(GJzCI%(lJiv&NMh*KA(7O=bYiem5kPOjt8!< zjqjq%@w#;g^0TKB_E9sd-XY8r`Y(k2)@=zXhn_J%b|_-(Zd*K&A`DKC`NT+Bh4Z6u zzx6jgV9KL)s1WjpuxqbI-Bi%bZvZ;TUiy$B@@{5l?HKK#j#K@h5MPG|V_g ze0k19EBq;$9W+d$i#EdisCe)^rG!StQ?POQ5Z3e-!X|HS_o}~~_0}vVeGYfgPE(lV z+=&L_Aw*kZGN5qjJF@j<6t;4ijhYic`_xc2{(}Jsr5}ebw_6w=SVv5Uk649%(53-i z2E6vR_b|6uA7eKc!NsceR%icJMculWB;`T^lXplJk{>GK#-4rfKwgrW;6{nkTkeip z+5kQFuEB!^kEl|68l3va8Q#tDlrEWQWgQtY$dQ!^M3>xhxWzmszpI zMMUt}&xL0=?=}W6e}FQNOEGgo92z)lf?vHbPpZigH|-sU*_*|A({%QeCodIvb{~Gy zi$#&FaC|u!jF&;ma}&}d(oWqk7ovsvXX@rHMpJ5b)ACY|`BFRyRySWFg_A2$x-bie zqtuD+DUQEAGKXWtG;j>K9?HMx4yUr;FmDXLk>Y)IsJyy>gz7wk9Q`AV?XJ(vr8}0e z1wFuUK^Pde*FdnVfu*8DEQ%UCkOQyJU|vilZM|fN$8{U20%MJGPDyNUuO->HFBQiX z8o+Sw5#srB7mi2$fLU0KTP;I53``)IUU_2B=8zoP0Y+UpijVVGGHG}5>DHr zzqS{Cgo)rCF3Y)8CS$jKDg>R~Ma@p=TevUF#CLmE!W_;m62{cP$93ZHwyuJjES&%= z@-D!6&8+}imqK0ALZUbjPuw%>NcK;bdVZmNcb#0i;NOQBB6%&-}n#mQAMbMQ@>w@iKhpt3i(6 zPGv2e2PN* z=o{Wtqr^S0PauNkXE666f6cTf)lPq`0VGDUt80z~9+T0%E z6^8%_RPH9uj|A)&zeG|Kz><8^Vtl+s2TvF#(682i*sY~$>@v|gFy?1PG)_DL6$N{2 z^-;o$PZJ@>Z=0Y-E(srqajvzf6gq#^Zf11VTwKesT5ZPHQo&OTaJjbuC6)z}Z7D}_ zsT^eouSa6BPzb$ZCm{$QJb_yV=ED8}V_da11by!0KxN)}vTK3??A%ewY@d}y$J0`Y zi=QU*B6*hJvljPT_os=u-t(Hd8CHg>-0s?SbQJE4BTeEOy17_R0oX5LNR z%5M7+`fz#L0O?HrMkC*6KAdrRGZt*IgN513NZ(oFva%D{xdo8Xz7B6r?gnYu2kiYc zS@KuSox1qdppRxL4knc{Ui;+H!E>rWLT57+6bOTMLk>22X0oY#Iq=l!=kBX&A?2VQ zS{iLd`#v7}`gD?-1}!)rAcFfgw^4&kDM7@SAh??~Ol+R+!zCNSF<^{iW#}o9ZB6mC zNuUHb&CPJK{Ul;Lzm=iVm89J08QAPt0|(T*nGOpB2=7QC)%|g>-)V^G1bZ>XXRayKlB%NJ*1N6ifQ;6E3R|o3@64>37 z4=qm1pnr^|BQi4))GXnpk|>w$$nZ*%_QRgR5A@@0hBimNAf~EmxZzL$%<@sdfOR{e zIG5|^=KAB6*fxkz_)Yr6JK^KA-8dxg#@M(IU{=p&G(2&JJlAstuRrCO*Kio@$7>7R8TF_#mw$NY(C=y*1Ag}DRBdOZ*!;5GsAJ_ZEnA@^Ah=H_7`p!Xo1P2NPMQB zj^f6ZD0hAnY?^2Vkry<%{#`J=Z*~G^cvW$};+Lqdu7?g44*|@tLYHbX?p3|Xx~h!B zjl}&RC(EZFT%%#^t1FG-*rD@7wpi8LD^Z>0=V9O5m+Y*!=TNRfpVnV^O;^XslC@*` zP;N30{zg_2&5TO;Ufar^nP*9|t8KteY604b4f7B6&c@-)E|_QZf;i;X;~k5mpvK)H zB{n%G;m(?#WAm$aBpq6p!Jd#ew{(B+=m&L zgaKUKFG5BJKGQ8?rqm&=h%fb}fbHY%GM=Qx!O6PoB*lC&+1KC#5ednxa|e*bv<48^(Jx7PN1eM+A(m6H~ei^A*sR9U~X%QpGC@1G-s4(&Y8fgmE6Hfirc`r z#{|K(!?E1^7Y!f%H}Q{aW;I+!K|_V1kuZ7HGm-q_(d-GUG9OVZ$Vd={&v825K) zaMj%f+*GNW+ln2)J?BiYbf+lK?sp~synGa#W1HdN!Wb@Vy9mFyExO90QmmbM9nZI` z3EHl5mZozQhsr&`<5fQS)tF4EXE#;O5QT(osaO_ghYq^y=s`D0!K7n>G+X>NbFAtv zU&}@qTkK}h4`)i@lyL>h-0mW3Ckk`yWjI6hz%V0hlR&f+f3iaaw^L$e-i3 z;k#GDxBeGUoXAJ!*x*SURC1KerjJSZlSWUR0B@$ z9rHSfiTw>Um9RHKcY8swFzbE}ZqYoa4xImxWN^~ob=H+Mwv9+I4>5?FM!I*L*+_wEK zSRy-)`zDxxdA1#1TZPnYiUy1s3-c0g7D0H$IE^~f4M!)2ksvB8(5ljhUuM#RiuQa= zNf-kguG3|7{}=qx31{x5ox&+cg2;-FEUXVYX;n2i9UmVrCW3}A5atSV4$?BxxbPK#VFLOJ z*+8v`25-aRRs122hojyZ!TQCt6PIPSR4Vv6nO>PrX=XM2d0xYwHPVN%{d3XKq8dv6 ztxT?Nm4YKG9NelpW`gsKdAfpnBH zI{Mn-R28mcsP`}3_M(^Gb5X!cdiKov{ZD96l^lMlKL}NaPEx;NSyVpm0V}4*;MO8j z3{{b&qwxn}qH8k_HcaOEYSl2mVt3%^PZ6G6*h@SeKf8Q%c08V*v|2E|Q<>-GDbMRa z5=IuAm@yOg<%8uRRaCd(cAVBXfSJv?C+*bn+Qy@FWZnqo=cI!Fr)o$TbOwnfJK&0f zFL@9nBk0+5o)H#AGZxcgnDcM0(z+zhf3w|`EcyHtcRhawJ%yX$)1N}---|+6v&?`u z$Fl@)j6cH|xw7(ln|?BDOEUTM+8Ay6V%a&Ze7MTUk%cQ~!JdIdq(E5|dBG0IFBxF3 zX=)2}ul*+FR-#q~mpY+2shgUY)`GutCQglBgSO}N(X+slW>q#(y-%6!$bZT75XS+$ zq_+_*ULWTitEz$vm+C;%+!kU+%*T| zk#&6mu%j`S@BAnb@>Z3S?Ip(q5OWl_opZz-19^BhrI(HkYoUq#ZStSPC5+p;1^ZWB zL)~?jQ2soDY`?U|YQ@+_xXf*VLjLa&*sadH@%J6BZP38u3mu3@hcq6CP2gO=jA|}D z28TZOVz}{9sE9m3W__EB8;e76d`w!hE1!cy_d_eELsf}{awkO!`wEkcbLvATm^HjQ+PfdSE2u%8gE)t0p6Ip z2mWEE(%1_n=pteWP4AC$4tWoJ-(vwbUdze3d%}V@TB@+;xUayu_!ho2SqFZnwXxjA z5ytZPm}QxY5&u4clA8}nkSOO1GV3IMZDLgSY7kzUI>Nl4QN#71f1tB&B7_?!kiM;t zscYz3ezu+(eCyi6Q=5Me?e*GmS85VU6=!mek5!gWeL5{2!w#U^flaVOwULaObN;~5 zMfB|-AJS9zglT>l#+%@=k&3B^^F*FrhLCxw7@hH#8hVStwP0@FQWc9Gmj~c))n}?w zmx}!}B=CvRRkC9LCo;hG+V~&;@=xc~61({#;Bjy{^vJdn(}y3hQypl_=yt1Bw)(K4 zHW;2oe<7(uh4|#bZFq7khWTt2N#g{E;LD8*c<%d7LhY2$JNYDf@f*?fNCK)DrQ#lP z9=~vXK^dt*^d8&*vFiCylk*){#irp{*n04OmyWv=;$gGa47$2qqsT04##G-FJTsmMyS6l-I+q8hDQqPhSAGV$ZT)oPXfpEE z@8i?b-T1=a2H1J$@s3vy%q-19ud_&pUNTs*%o?w7pX(}(0wUuP2n*);k;T?V)c(OO zf?@jjPn``;-8L81x~EbVAMSh~Zv=d;5ZGYc59GE4*nLUF?V$(ZY`6&C+vSHBe^pV< zooVFIp%yY{sGcs-y}*8U*C)K+%_x`ChnHSp>D@C@_nEM-bMBkdUY`xY?{EbF^ph+y){bb z&xAT?qSIU(ag9(lnfo+}8to{<+_>> ") + prompt = rag_args.prompt_structure + + if rag_args.retriever_type == 'dpr': + encoded_input = tokenizer(input_text, padding=True, truncation=True, return_tensors='pt') + with torch.no_grad(): + model_output = embed(**encoded_input) + embeddings = cls_pooling(model_output).numpy() + + _, ids = index.search(embeddings, k=top_k) + docs = [corpus[int(id)] for id in ids[0]] + elif rag_args.retriever_type == 'bm25': + tokenized_query = input_text.split() + docs = bm25.get_top_n(tokenized_query, corpus, n=top_k) + + background = '\n'.join(docs) + background = background[-(model.get_max_length()-len(input_text)-len(prompt)):] + all_input = prompt.format(input_text=input_text, background=background) + input_dataset = dataset.from_dict({ + "type": "text_only", + "instances": [ { "text": all_input } ] + }) + output_dataset = inferencer.inference( + model=model, + dataset=input_dataset, + max_new_tokens=inferencer_args.max_new_tokens, + temperature=inferencer_args.temperature, + ) + output = output_dataset.to_dict()["instances"][0]["text"] + print('Bot:') + print(output) + + +if __name__ == "__main__": + main() diff --git a/contrib/rag/run_rag_evaluation.sh b/contrib/rag/run_rag_evaluation.sh new file mode 100644 index 000000000..34615eefb --- /dev/null +++ b/contrib/rag/run_rag_evaluation.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +if [ ! -d data/MedQA-USMLE ]; then + cd data && ./download.sh MedQA-USMLE && cd - +fi + +lora_args="" +retriever_type="" +corpus_index_path="" +prompt_structure="" + +top_k="" +while [[ $# -ge 1 ]]; do + key="$1" + case ${key} in + -m|--model_name_or_path) + model="$2" + shift + ;; + --lora_model_path) + lora_args="--lora_model_path $2" + shift + ;; + -r|--retriever_type) + retriever_type="--retriever_type $2" + shift + ;; + --corpus_index_path) + corpus_index_path="--corpus_index_path $2" + shift + ;; + --prompt_structure) + prompt_structure="--prompt_structure $2" + shift + ;; + --top_k_retrieve) + top_k="--top_k_retrieve $2" + shift + ;; + *) + echo "error: unknown option \"${key}\"" 1>&2 + exit 1 + esac + shift +done + +CUDA_VISIBLE_DEVICES=0 \ + deepspeed examples/rag_evaluation.py \ + --answer_type medmcqa \ + --model_name_or_path gpt2 \ + --dataset_path data/MedQA-USMLE/validation \ + --deepspeed examples/ds_config.json \ + --inference_batch_size_per_device 1 \ + --metric accuracy + ${retriever_type} \ + ${corpus_index_path} \ + ${prompt_structure} \ + ${top_k} diff --git a/contrib/rag/run_rag_inference.sh b/contrib/rag/run_rag_inference.sh new file mode 100644 index 000000000..73373b1e7 --- /dev/null +++ b/contrib/rag/run_rag_inference.sh @@ -0,0 +1,56 @@ +#!/bin/bash +# An interactive inference script without context history, i.e. the chatbot +# won't have conversation memory. + +model=gpt2 +lora_args="" +retriever_type="" +corpus_index_path="" +prompt_structure="" +top_k="" +while [[ $# -ge 1 ]]; do + key="$1" + case ${key} in + -m|--model_name_or_path) + model="$2" + shift + ;; + --lora_model_path) + lora_args="--lora_model_path $2" + shift + ;; + -r|--retriever_type) + retriever_type="--retriever_type $2" + shift + ;; + --corpus_index_path) + corpus_index_path="--corpus_index_path $2" + shift + ;; + --prompt_structure) + prompt_structure="--prompt_structure $2" + shift + ;; + --top_k_retrieve) + top_k="--top_k_retrieve $2" + shift + ;; + *) + echo "error: unknown option \"${key}\"" 1>&2 + exit 1 + esac + shift +done + +accelerate launch --config_file ../../configs/accelerator_singlegpu_config.yaml \ + rag_inference.py \ + --deepspeed ../../configs/ds_config_chatbot.json \ + --model_name_or_path ${model} \ + --use_accelerator True \ + --max_new_tokens 256 \ + --temperature 1.0 \ + ${lora_args} \ + ${retriever_type} \ + ${corpus_index_path} \ + ${prompt_structure} \ + ${top_k}