diff options
author | David Timber <dxdt@dev.snart.me> | 2024-05-04 01:18:58 +0900 |
---|---|---|
committer | David Timber <dxdt@dev.snart.me> | 2024-05-04 01:18:58 +0900 |
commit | 969d9d3efe661a1081878c0c55737c3cd5f2a49c (patch) | |
tree | 2baf7deeab10eb8b1e48482e4dc4e4442f0dcf94 /src | |
parent | b3e1d10646ea5b0bdbfdbdbf8bd5e21487248c5c (diff) |
- Parameterise the "dryrun" config
- Parameterise the "temperature" prompt param
- Fix JWTMarkerExtractor filter logic
- Replace "jti" with "pid" in the marker JWT to combat replay attacks
- Prompt optimisation through JSON output
- Parameterise the response
- Change from yes/no to score response
- Add the score in the comment format
Diffstat (limited to 'src')
-rw-r--r-- | src/okkybot/__init__.py | 3 | ||||
-rw-r--r-- | src/okkybot/__main__.py | 88 |
2 files changed, 65 insertions, 26 deletions
diff --git a/src/okkybot/__init__.py b/src/okkybot/__init__.py index e827599..406750a 100644 --- a/src/okkybot/__init__.py +++ b/src/okkybot/__init__.py @@ -57,7 +57,8 @@ class JWTMarkerExtractor (HTMLParser): for kv in attrs: if kv[0] == "href": u = urlparse(kv[1]) - if u.hostname != "" or u.path != "": continue + if u.hostname or u.path: + continue qs = parse_qs(u.query) self.marker.extend(qs.get("okkybot-marker", [])) diff --git a/src/okkybot/__main__.py b/src/okkybot/__main__.py index e924581..4d754fd 100644 --- a/src/okkybot/__main__.py +++ b/src/okkybot/__main__.py @@ -22,8 +22,6 @@ MAX_POSTS_PER_TOPIC = 50 TARGET_TOPICS = [ "community" ] POST_TOKEN_LIMIT = 5000 # $0.0025 spending limit per post -dryrun = False - def getCache () -> StateCache: try: with open(CACHE_FILENAME) as f: @@ -75,36 +73,40 @@ def fetchPostData (url, s: requests.Session) -> dict[str, Any]: doc = pyjson5.loads(fetchAPIData(url, s)) return doc -def issueMarkerJWT () -> str: +def issueMarkerJWT (pid: Any) -> str: global conf mc = conf["marker"] - id = str(uuid.uuid4()) payload = { "iss": "okkybot", "sub": "marker", - "jti": id + "pid": str(pid) } return jwt.encode(payload, mc["secret"], algorithm = mc["algorithm"]) -def validateMarkerJWT (token: str) -> bool: +def validateMarkerJWT (token: str, pid: int) -> bool: global conf mc = conf["marker"] payload = jwt.decode(token, mc["secret"], algorithms = mc["algorithm"]) - return payload["iss"] == "okkybot" and payload["sub"] == "marker" - -def writeComment (pid: int, result: str, s: requests.Session): - global dryrun + return ( + payload["iss"] == "okkybot" and + payload["sub"] == "marker" and + payload["pid"] == str(pid)) +def writeComment (pid: int, result: dict[str, Any], s: requests.Session): marker_href = '''?okkybot-marker={marker}'''.format( - marker = urllib.parse.quote(issueMarkerJWT())) + marker = urllib.parse.quote(issueMarkerJWT(pid))) text = "" - text += '''<p>킁킁. 
AI는 이 글이 정치적이라고 생각합니다:</p>''' - text += '''<blockquote><p>{result}</p></blockquote>'''.format( - result = html.escape(result)) + text += '''<p>(킁킁) AI는 이 글이 정치적이라고 생각합니다:</p>''' + text += "<blockquote>" + text += '''<p>점수: {score}/10</p>'''.format( + score = html.escape(str(result["score"]))) + text += '''<p>설명: {msg}</p>'''.format( + msg = html.escape(result.get("explanation", "").strip())) + text += "</blockquote>" text += '''<p><a href="{href}">.</a></p>'''.format(href = marker_href) body = { @@ -118,7 +120,7 @@ def writeComment (pid: int, result: str, s: requests.Session): print({ "action": "comment", "data": body }) - if not dryrun: + if not conf.get("dryrun", False): with s.post( "{api}/comments".format(api = API_ENDPOINT), json = body) as req: @@ -137,23 +139,57 @@ def determineViability (x: list[str]) -> bool: return len(tokens) <= POST_TOKEN_LIMIT def doLLMPrompt (title: str, stripped_body: str): - prompt = '''Is this post politically charged? -Answer Yes or No. Give a short explanation in Korean only if the answer is yes. + llm_params = conf.get("llm_params", {}).get("gpt", { + "temperature": 0, + "seed": 0 + }) + prompt = '''On the scale of 1 to 10, how politically charged is this post? +Rate the post. In Korean, give a short explanation why the post is political only if the score is over 5. TITLE: {title} BODY: {body}'''.format( title = title, body = " ".join(stripped_body)) - messages = [ { "role": "user", "content": prompt } ] + messages = [ + { + "role": "system", + "content": "You are a helpful assistant that determines whether an " + "internet post is politically charged and returns the result in" + " JSON format." 
+ }, + { "role": "user", "content": prompt } + ] + functions = [ + { + "name": "get_result", + "description": "Get the political chargedness analysis result", + "parameters": { + "type": "object", + "properties": { + "score": { + "type": "number", + "description": "The political chargedness score" + }, + "explanation": { + "type": "string", + "description": "The explanation" + } + } + } + } + ] rsp = openai.chat.completions.create( model = "gpt-3.5-turbo", messages = messages, + functions = functions, + function_call = "auto", + temperature = float(llm_params["temperature"]), ) if rsp.choices: - return rsp.choices[0].message.content + return pyjson5.loads(rsp.choices[0].message.function_call.arguments) -def hasMarkerInComments (comments: list[dict[str, Any]]) -> bool: +def hasMarkerInComments (pid, comments: list[dict[str, Any]]) -> bool: ext = JWTMarkerExtractor() for c in comments: @@ -161,7 +197,7 @@ def hasMarkerInComments (comments: list[dict[str, Any]]) -> bool: try: ext.feed(c["text"]) for m in ext.marker: - if validateMarkerJWT(m): + if validateMarkerJWT(m, pid): return True except Exception as e: sys.stderr.write( @@ -184,7 +220,7 @@ def processPost ( combined = stripped.copy() combined.append(title) - if hasMarkerInComments(comments): + if hasMarkerInComments(pid, comments): result = ProcPostResult.MARKER.value elif determineViability(combined): result = doLLMPrompt(title, stripped) @@ -200,9 +236,11 @@ def processPost ( ] }) - if (result[:len(ProcPostResult.YES.value)].lower() == - ProcPostResult.YES.value.lower()): - writeComment(pid, result[len(ProcPostResult.YES.value) + 1:].strip(), s) + + if hasattr(result, "get"): + score = result.get("score", 0) + if score > 5: + writeComment(pid, result, s) def doPost (topic: str, pid, s: requests.Session): url = "{api}/articles/{pid}".format( |