Commit 03ce1c5: 🎉 publish
1 parent 13a04d4 commit 03ce1c5
16 files changed (+546 −75 lines)

README.md

Lines changed: 168 additions & 19 deletions
# Simple-SimCSE: A simple implementation of SimCSE

## Introduction

However, [the official implementation of SimCSE](https://github.com/princeton-nlp/SimCSE) is somewhat complex.
Of course, the official implementation is great.
However, a simpler implementation is more helpful for understanding; in particular, it matters for those who are new to deep learning research or who are just starting research on sentence embeddings.

Therefore, we implemented a simple version of SimCSE, with minimal abstraction and minimal use of external libraries.

Using only basic features of [PyTorch](https://github.com/pytorch/pytorch) and [transformers](https://github.com/huggingface/transformers), we developed code to perform fine-tuning and evaluation of SimCSE from scratch.
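The core idea of SimCSE is compact: each sentence is encoded twice, dropout makes the two embeddings differ, and an in-batch contrastive (InfoNCE) loss pulls the two views of the same sentence together while pushing other sentences apart. As a rough pure-Python sketch of that objective (the repository itself uses PyTorch; `sim_matrix` and the temperature value here are illustrative):

```python
import math

def info_nce_loss(sim_matrix, temperature=0.05):
    """Minimal sketch of SimCSE's in-batch contrastive objective.

    sim_matrix[i][j] is the cosine similarity between the first dropout
    view of sentence i and the second dropout view of sentence j; the
    positive pair for row i sits on the diagonal (column i).
    """
    n = len(sim_matrix)
    total = 0.0
    for i in range(n):
        logits = [s / temperature for s in sim_matrix[i]]
        m = max(logits)  # subtract the max for numerical stability
        log_denom = m + math.log(sum(math.exp(l - m) for l in logits))
        total += log_denom - logits[i]  # cross-entropy with target class i
    return total / n
```

With a well-trained encoder the diagonal similarities dominate each row, so the loss is small; a batch whose positives are not the most similar pairs yields a larger loss.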

## Installation & Training

For development, we used [poetry](https://python-poetry.org/), a dependency management and packaging tool for Python.

If you use poetry, you can install the necessary packages with the following command.

```bash
poetry install
```

Alternatively, you can install them using `requirements.txt`.

```bash
pip install -r requirements.txt
```

The `requirements.txt` is generated with the following command.

```bash
poetry export -f requirements.txt --output requirements.txt
```

Then, you must execute `download.sh` beforehand to download the training and evaluation datasets.
`download.sh` collects the STS and training datasets used in the paper in parallel.

```bash
bash download.sh
```
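As a rough Python analogue of what such a parallel fetch does (the script itself is shell; `fetch` below is a hypothetical stand-in for an actual HTTP download function, not something the repository defines):

```python
from concurrent.futures import ThreadPoolExecutor

def fetch_all(urls, fetch, max_workers=8):
    """Download several files concurrently.

    `fetch` is a stand-in for a real download function (e.g. one built
    on urllib); it is called once per URL from a pool of worker threads.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # map() preserves input order even though downloads finish out of order
        return list(pool.map(fetch, urls))
```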

Finally, you can train your model as below.

```bash
poetry run python train.py

# or
# python train.py
```
## Evaluation (Unsup-SimCSE)
While working on this implementation, we investigated how well an Unsup-SimCSE model trained with it performs.

We performed fine-tuning of Unsup-SimCSE **50 times** with different random seeds ([0, 49]), using the same dataset and hyperparameters as described in the paper (see `train-original.sh` and `train-multiple.sh`).
We evaluated the resulting models on 7 STS tasks (STS12--16, STS Benchmark, SICK-R).
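STS tasks score a model by how well the ranking of its predicted sentence similarities agrees with human judgments; SimCSE, like most sentence-embedding work, reports Spearman's rank correlation. A minimal tie-free sketch of the metric (real evaluations, e.g. `scipy.stats.spearmanr`, also handle tied ranks):

```python
def spearman(xs, ys):
    """Spearman's rank correlation, assuming no tied values.

    Ranks both lists, then computes the Pearson correlation of the ranks.
    """
    def ranks(values):
        order = sorted(range(len(values)), key=lambda i: values[i])
        r = [0.0] * len(values)
        for rank, idx in enumerate(order):
            r[idx] = float(rank)
        return r

    rx, ry = ranks(xs), ranks(ys)
    n = len(xs)
    mx, my = sum(rx) / n, sum(ry) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    sx = sum((a - mx) ** 2 for a in rx) ** 0.5
    sy = sum((b - my) ** 2 for b in ry) ** 0.5
    return cov / (sx * sy)
```

A perfectly monotonic relationship between predicted similarities and gold scores gives 1.0 regardless of scale, which is why the metric suits STS.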
We evaluated models every 250 training steps on the development set of STS-B and kept the best checkpoint for the final evaluation (as in the original paper; see its Appendix A).
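The evaluate-and-keep-best procedure can be sketched as follows (a simplification: `scores_by_step` stands in for actually running the STS-B dev evaluation at each interval):

```python
def keep_best_checkpoint(scores_by_step, eval_interval=250):
    """Return (best_step, best_score) among periodic dev evaluations.

    scores_by_step maps a training step to its STS-B dev score; only
    steps that fall on the evaluation interval are considered, mirroring
    "evaluate every 250 steps, keep the best checkpoint".
    """
    best_step, best_score = None, float("-inf")
    for step in sorted(scores_by_step):
        if step % eval_interval == 0 and scores_by_step[step] > best_score:
            best_step, best_score = step, scores_by_step[step]
    return best_step, best_score
```

In the real training loop the model's `state_dict()` would be copied at the best step and restored before the final 7-task evaluation.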
### Overall

The table below shows the average performance on each STS task over 50 runs.
|               | STS12     | STS13     | STS14     | STS15     | STS16     | STS-B     | SICK-R    | Avg.      |
| ------------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- | --------- |
| BERT-base     | **67.64** | 80.42     | 72.68     | 80.25     | 78.12     | 76.28     | 70.35     | 75.11     |
| BERT-large    | 65.83     | 80.44     | 71.60     | 81.45     | 76.85     | 75.96     | **72.52** | 74.95     |
| RoBERTa-base  | 64.92     | 80.06     | 71.52     | 79.92     | 78.73     | 78.67     | 68.39     | 74.60     |
| RoBERTa-large | 66.46     | **81.93** | **73.93** | **82.71** | **80.58** | **80.57** | 70.82     | **76.71** |
Overall, RoBERTa-large achieved the best average performance; however, all models performed worse than reported in the paper (see its Table 5).

The reason is not clear, but the following two explanations are possible:

- There is something wrong with our implementation.
- The hyperparameters used in the paper were tuned with a single random seed (see https://github.com/princeton-nlp/SimCSE/issues/63), so they may not suit our implementation.
  - In that case, it would be very difficult to perfectly reproduce the performance of SimCSE with a different implementation.
We also found that SimCSE is somewhat sensitive to random seeds.
The following figures show histograms and KDE plots of the average performance of each fine-tuning run.

![histplot](./.github/images/performances/overall-hist.png)![kdeplot](./.github/images/performances/overall-kde.png)

BERT-base and BERT-large show a larger variance in performance than RoBERTa-base and RoBERTa-large.

There are many possible reasons for this difference, but the most likely cause is the difference in batch size:
BERT-base and BERT-large are fine-tuned with batch size 64, whereas RoBERTa-base and RoBERTa-large are fine-tuned with batch size 512.
We suspect that smaller batch sizes lead to less stable performance.
### Details

To encourage future research, we show the minimum, mean, and maximum performance of each model on each STS task.

The results of all experiments for each model are in the `results` directory; we hope they are helpful.

In addition, to investigate how the performance of each model changes during training, we show the performance transition on the development set of STS Benchmark.
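The min/mean/max figures in the tables below can be reproduced from those CSVs with a few lines of standard-library Python; the column names match the header of `results/bert-base-uncased.csv`:

```python
import csv
import io
import statistics

def summarize(csv_text):
    """Compute (min, mean, max) per STS column from a results CSV string."""
    rows = list(csv.DictReader(io.StringIO(csv_text)))
    summary = {}
    for col in ["sts12", "sts13", "sts14", "sts15", "sts16", "stsb", "sick", "avg"]:
        values = [float(row[col]) for row in rows]
        summary[col] = (min(values), statistics.mean(values), max(values))
    return summary
```

For example, `summarize(Path("results/bert-base-uncased.csv").read_text())["avg"]` yields the (min, mean, max) of the 7-task average over the 50 runs.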
#### BERT-base

|      | STS12 | STS13 | STS14 | STS15 | STS16 | STS-B | SICK-R | Avg.  |
| ---- | ----- | ----- | ----- | ----- | ----- | ----- | ------ | ----- |
| min  | 64.64 | 77.91 | 69.36 | 77.60 | 76.26 | 73.99 | 67.80  | 73.16 |
| mean | 67.64 | 80.42 | 72.68 | 80.25 | 78.12 | 76.28 | 70.35  | 75.11 |
| max  | 70.24 | 82.89 | 74.93 | 82.80 | 79.95 | 78.29 | 71.99  | 76.70 |

![BERT-base](./.github/images/transitions/bert-base-uncased.png)

The performance of BERT-base is somewhat unstable, which is thought to be largely due to its batch size.

On the other hand, it is worth noting that BERT-base achieved the best mean performance on STS12.
#### BERT-large

|      | STS12 | STS13 | STS14 | STS15 | STS16 | STS-B | SICK-R | Avg.  |
| ---- | ----- | ----- | ----- | ----- | ----- | ----- | ------ | ----- |
| min  | 61.00 | 73.62 | 65.20 | 76.37 | 69.35 | 69.26 | 67.58  | 70.31 |
| mean | 65.83 | 80.44 | 71.60 | 81.45 | 76.85 | 75.96 | 72.52  | 74.95 |
| max  | 69.12 | 84.39 | 75.18 | 83.63 | 79.34 | 79.10 | 75.15  | 77.28 |

![BERT-large](./.github/images/transitions/bert-large-uncased.png)

BERT-large has the most unstable performance of all the models (min = 70.31, max = 77.28).
However, its performance transition seems much more stable than that of BERT-base.

Furthermore, BERT-large achieved the highest performance on SICK-R.
#### RoBERTa-base

|      | STS12 | STS13 | STS14 | STS15 | STS16 | STS-B | SICK-R | Avg.  |
| ---- | ----- | ----- | ----- | ----- | ----- | ----- | ------ | ----- |
| min  | 63.64 | 79.28 | 70.92 | 79.09 | 78.22 | 78.04 | 67.17  | 74.17 |
| mean | 64.92 | 80.06 | 71.52 | 79.92 | 78.73 | 78.67 | 68.39  | 74.60 |
| max  | 66.17 | 81.18 | 72.32 | 80.60 | 79.40 | 79.41 | 69.15  | 75.14 |

![RoBERTa-base](./.github/images/transitions/roberta-base.png)

The performance of RoBERTa-base is stable.

Its large batch size should put it at a slight disadvantage, since it allows fewer evaluation steps, yet RoBERTa-base achieved relatively high performance.
#### RoBERTa-large

|      | STS12 | STS13 | STS14 | STS15 | STS16 | STS-B | SICK-R | Avg.  |
| ---- | ----- | ----- | ----- | ----- | ----- | ----- | ------ | ----- |
| min  | 63.94 | 80.63 | 72.62 | 81.57 | 79.53 | 78.54 | 68.19  | 75.64 |
| mean | 66.46 | 81.93 | 73.93 | 82.71 | 80.58 | 80.57 | 70.82  | 76.71 |
| max  | 68.17 | 83.54 | 75.17 | 84.03 | 81.33 | 81.91 | 72.38  | 77.69 |

![RoBERTa-large](./.github/images/transitions/roberta-large.png)

RoBERTa-large achieved the best overall performance, and its performance changes were stable.

As can be seen from the performance transitions of BERT-large and RoBERTa-large, there is a fairly significant difference between the two models.
Whether this is due to hyperparameters or to differences in pre-training data and tasks is a matter for further research.
## In this work

SimCSE has greatly advanced research on sentence embeddings.
The clarity of SimCSE is very impressive, and a wide range of applications will be developed in the future.

We hope that this implementation helps in understanding SimCSE and facilitates future research.
## Citation

```bibtex
@misc{
  hayato-tsukagoshi-2022-simple-simcse,
  author = {Hayato Tsukagoshi},
  title = {Simple-SimCSE: A simple implementation of SimCSE},
  year = {2022},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/hppRC/simple-simcse}}
}
```

eval.py

Lines changed: 2 additions & 0 deletions
```python
# if you want to perform evaluation only, you can use this snippet ↓

import json
from pathlib import Path
from typing import List
```

results/bert-base-uncased.csv

Lines changed: 51 additions & 0 deletions
```csv
sts12,sts13,sts14,sts15,sts16,stsb,sick,avg,best-step,best-stsb
66.16695771005516,79.50748652829297,71.55331112150368,80.07930487275637,77.68491225491928,74.91361598920044,69.92975253229812,74.26219157271801,11750,80.05009883773363
66.26317936602125,78.22176616397071,71.87433996740516,78.78145040014111,77.85025614286671,75.06112400603587,70.88629476141153,74.13405868683606,13750,80.96082028473755
69.94117330237958,82.77064846023775,74.93284099415656,80.62464831558664,79.2879498794275,77.79312724974827,71.45713423722468,76.68678891982299,6500,81.97453664267702
66.76044678482099,80.99617959743549,71.11946523461367,79.2895274078802,78.51161199326975,74.88616603101839,67.7962431367603,74.19423431225697,10000,80.21803489668858
67.84092025900632,80.1582817347236,73.42280127118222,80.18693086545613,77.93263895074845,75.91515770830128,71.73249057836681,75.31274590968353,8250,81.42681509031132
65.18592347899028,79.6127690753521,72.19061270487084,80.66881728558691,78.31950669020335,75.41657295020336,70.11912051744554,74.50190324323606,7750,81.16658092495581
67.18309106725641,82.60063144273676,73.57977179271997,80.38734038694724,79.9480116405962,76.10343000881197,70.46090230673455,75.75188266368616,11750,80.4979385421897
65.444442474682,79.34985222851475,69.35828256547444,78.92181312760644,76.69023010216446,73.99313381425934,68.36114182181704,73.15984230493122,4250,79.23891856262696
69.27925575682569,80.46185908399592,73.41432448511395,81.05815844211497,79.1827514063758,77.30697741147681,69.24532366478269,75.70695003581227,11500,82.26869414105192
68.72185148007489,81.20412919195272,73.73814283804026,80.60779660868413,78.97307210152536,76.44088767372277,71.2335827256805,75.84563751709723,4750,81.73368238393424
68.11822535101703,79.62006365113096,72.6230473938506,80.84676014770312,79.29373406062878,78.19686470497493,70.3930352862552,75.58453294222294,10000,81.95210104000536
66.80562557805854,79.1418773655429,72.76185572108753,79.44353304493646,77.22372614187513,76.54405342679344,69.8824833717249,74.54330780714555,10250,81.95994214861945
65.35019828571303,79.9052902300963,72.91413147214982,80.02945672142256,76.25640921525782,74.05461615917828,70.67999827091077,74.1700143363898,6750,80.35662912044782
67.68243702754938,80.02576168031679,72.77980728707657,80.75393266021577,78.15101240458512,77.15024780476656,69.74045544000214,75.1833791863589,11750,81.60145666990807
69.62583078939795,82.56168998039225,74.58829402256949,81.50097566502744,78.21675994527536,77.24263052686975,71.98962406832923,76.53225785683733,12000,82.32620076436775
66.63000441511196,79.61393783035213,72.04079180577256,79.8873858634253,78.16728793904895,77.10563337224059,71.56707347880474,75.00173067210804,10000,81.12595046987003
66.68085835137123,79.49715201557017,72.54194068906823,81.47044583835206,79.21203168026945,76.89186495046904,70.18040667077678,75.21067145655385,7500,81.70771047988976
67.00149253892572,80.00233177268257,72.49167635876643,78.9173412258112,78.2873536622553,75.16197641949682,70.4089435524236,74.61015936148023,7750,80.72621994490142
67.40271232270555,82.89485839611525,73.26802684174788,81.16639403998839,77.87219108308268,77.04788764238701,71.179008437331,75.83301125190826,11000,82.2109457026226
69.86695701217019,80.5217599131089,73.46186753073229,79.83797957508327,78.26129347056697,76.96670580520433,69.7264404795079,75.52042911233912,8000,82.71249836191669
68.85979078781996,81.43768053557419,73.69392091328596,80.42903422005598,78.96000351500898,77.36844532488834,69.91037259620404,75.80846398469107,6250,82.03223098854232
65.49523092903128,77.90893515236787,69.93367427334046,79.4456870865562,77.27944202832602,75.01643283454187,70.52332140929344,73.65753195906531,7000,81.13478428379551
67.4400470036668,81.20369233240262,74.14522227815709,81.02849638461701,78.95559514731282,76.78539684804736,71.1705217331261,75.81842453247567,10750,81.87844387173604
67.32388703368817,81.09625228689603,72.2497726642738,80.18347632467415,77.3956061472832,76.62493383327337,70.83003065047656,75.10056556293789,10750,81.20096546825789
68.23148129395223,78.79819381831568,72.76297210695212,80.54456172326496,77.14955042091925,76.13589471558679,69.47243399230216,74.72786972447045,7750,80.88739656858664
67.33569650891795,78.97927806305036,69.7808503554348,79.13873064831016,76.6517000272114,75.05993970614084,71.01309837648868,73.99418481222202,7000,80.4899585718905
64.641163662889,78.54717093984021,70.56932463626033,77.59628863570191,78.63993989431668,74.75275527701973,71.16864539323586,73.70218406275197,11000,80.47974415203004
68.99553311652724,82.72940251262163,74.61246785281986,82.51542874345382,78.63606127501866,77.71762412268693,71.71365044983645,76.70288115328064,7750,82.77159431017274
65.31038457099989,80.5105862736508,71.9810633066587,79.62711284230224,77.88811527615022,75.24356116704435,71.5265653029073,74.58391267710192,13750,79.6887046381537
67.55532931580993,82.18894728106,73.79500446499145,80.25506959090605,79.07647409998164,76.43149665660529,71.05459542481056,75.76527383345214,7750,81.46270070226497
70.2421156535613,81.71920599772386,73.85350543787412,82.79769329676631,78.74967588859963,78.29396584279151,70.09987787051736,76.53657714111917,10000,83.87317300632702
70.18990907999861,81.29117884020194,73.02596651745607,81.42371284198127,78.68114461493562,76.70415502060483,71.12055626901696,76.06237474059932,11250,82.04715679955454
68.42049551011873,80.74429545614026,73.54989244904802,80.71979216152651,77.10811658046973,76.10088754371813,70.33155686799542,75.28214808128813,8250,81.8627887007899
67.8170435817371,79.99152577144463,73.24029902475093,80.35443223085626,77.5806517867792,76.51802499719635,68.52293636591749,74.86070196552599,5500,81.08876183210612
69.66206282162769,81.42653091894269,74.13448151839879,81.51440719206907,78.81025432846846,77.91333435090787,70.36578861940058,76.2609799642593,11750,82.87263708452977
68.84393257932999,79.69598592610662,73.27967729865412,81.56509506148959,78.72859941806743,75.95077658808712,70.48704110691389,75.50730113980697,12750,80.87245136254325
69.000725542659,81.1733007875713,71.91942485228604,79.33539994209397,77.25600285075085,75.40112207439661,68.64361668240724,74.67565610459499,7500,81.24290250787527
67.55455862249805,80.19627659314207,72.9740281647014,79.47850058592874,78.9759703944144,78.14387264271426,71.64926277767889,75.5674956830111,12000,82.3841486691269
69.1188620454756,82.19396645029423,73.59062742997145,81.44241540476742,76.86978416253582,77.71714261429632,70.53346753657428,75.92375223484501,9000,82.33881055954676
66.17657436235235,79.3408913583053,72.15081272760844,79.79749352628225,77.43663866318869,76.7513620848933,70.54320392082515,74.59956809192222,10500,80.68007009098885
66.52654629722694,79.79760888864269,70.04178787732859,78.66083723481844,77.62410976229275,76.18208914172135,68.53965069326975,73.91037569932864,3750,80.6287065151126
69.14390128631806,80.69357223651656,73.282045405042,81.0302064368027,77.716040220836,76.83422917495776,70.19577466515744,75.5565384893758,10500,82.54661604103384
70.24398439185585,81.00310528936615,73.21607415747185,80.16129258757032,78.27874707498785,76.0927065695595,69.06723730276684,75.43759248193976,3250,81.1873268467461
66.16425771068303,80.32148481092428,71.80379044123485,79.41899316763849,77.10528870179611,74.40269678798327,69.32670982076044,74.07760306300293,14000,79.27557157583101
69.27752742220315,79.19659673788206,72.84378605821391,79.96659539347498,77.42283018014929,75.58198849181119,69.61517788716357,74.8435003101283,4500,81.40084781260454
66.03826621446525,79.62285065490839,72.32907672540018,79.1185649377093,77.66477050573425,75.0292535108525,70.45093652690461,74.32195986799636,11250,80.23456622482301
66.81610934099211,78.56041018649077,71.94176467491363,81.0785452331435,79.17918873894611,77.63800457993622,70.1731598509561,75.05531180076835,10500,82.23668967056672
66.50037772413434,82.88614843098347,74.30230229406864,80.20845201486347,79.78232991035078,75.64015476902112,71.6504037904306,75.85288127626463,8500,80.5085224637052
67.28279416476255,81.126093797972,72.10151812814615,78.86910788279437,78.15565463503832,75.4904153085531,69.79491293715557,74.6886424077746,13000,80.47955020171825
68.08076457021727,78.15810524809976,72.43338167915351,80.18656191889583,76.83354455620783,76.37409430499983,70.85691898560412,74.7033387518826,8000,82.41631773047301
```
