summaryrefslogtreecommitdiff
path: root/shitoutcode/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'shitoutcode/__init__.py')
-rw-r--r--shitoutcode/__init__.py86
1 files changed, 86 insertions, 0 deletions
diff --git a/shitoutcode/__init__.py b/shitoutcode/__init__.py
new file mode 100644
index 0000000..67da767
--- /dev/null
+++ b/shitoutcode/__init__.py
@@ -0,0 +1,86 @@
+import openai
+import pyjson5
+
+
+class Env:
+ def __init__(self):
+ self.model = "gpt-4o"
+ self.output_tokens_min = 250
+ self.output_tokens_max = 500
+ self.temp: float = None
+ self.seed: int = None
+ self.max_tokens = 8000
+
+class SourceFile:
+ def __init__(self):
+ self.name: str = None
+ self.category: str = None
+ self.contents: str = None
+
+class LLMAPIException (Exception): ...
+
+def __do_prompt (lang: str, extra_prompt: str, env: Env):
+ prompt = '''
+Write me a program source code written in %s. Write anything you'd like.
+Use more than %d words, but no more than %d words in the code.''' % (
+ lang, env.output_tokens_min, env.output_tokens_max)
+ if extra_prompt:
+ prompt = ' ' + extra_prompt
+
+ messages = [
+ {
+ "role": "system",
+ "content": "You're a helpful assistant that writes "
+ + "computer program of any kind"
+ },
+ { "role": "user", "content": prompt }
+ ]
+ functions = [
+ {
+ "name": "get_result",
+ "description": "Output the source code",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "cat": {
+ "type": "string",
+ "description": "one word describing program functions"
+ },
+ "filename": {
+ "type": "string",
+ "description": "The source code file name" # TODO: strip unallowed characters in Unix fs
+ },
+ "code": {
+ "type": "string",
+ "description": "The source code"
+ }
+ }
+ }
+ }
+ ]
+
+ return openai.chat.completions.create(
+ model = env.model,
+ messages = messages,
+ functions = functions,
+ function_call = "auto",
+ temperature = env.temp,
+ seed = env.seed,
+ max_tokens = env.max_tokens
+ )
+
+def gen_rand_srccode (lang: str, extra_prompt: str, env: Env) -> SourceFile | str:
+ rsp = __do_prompt(lang, extra_prompt, env)
+ match rsp.choices[0].finish_reason:
+ case 'stop':
+ return rsp.choices[0].message.content
+ case 'function_call':
+ choice = pyjson5.loads(rsp.choices[0].message.function_call.arguments)
+
+ ret = SourceFile()
+ ret.name = choice["filename"]
+ ret.category = choice.get("cat")
+ ret.contents = choice["code"]
+
+ return ret
+ case _: raise LLMAPIException(rsp.choices[0].finish_reason)